perf(gdn): optimize MTP kernel with ILP rows and SMEM v caching #2618

Merged
aleozlx merged 4 commits into flashinfer-ai:main from ameynaik-hub:ameyn/improve_gdn_mtp_fp32state
Mar 7, 2026

Conversation

@ameynaik-hub
Contributor

@ameynaik-hub ameynaik-hub commented Feb 22, 2026

Improve GDN MTP (Multi-Token Processing) kernel performance by:

  • Adding instruction-level parallelism (ILP) with 1/2/4/8-row processing
  • Enabling shared memory (SMEM) caching for V tensor tiles
  • Dispatching dynamically based on batch size and sequence length
  • Changing the default to disable_state_update=False (h is updated by default)
  • Keeping the original kernel for BS < 8 to avoid regression

This commit adds the optimized kernel while preserving the original kernel
for small batch sizes. The dispatch logic selects:

  • BS < 8: Original kernel (no ILP overhead, matches baseline performance)
  • BS >= 8: Optimized kernel with ILP rows and SMEM v caching
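
The dispatch rule above can be sketched as follows. This is a minimal illustration, assuming a single batch-size threshold of 8 as stated in the description; `ILP_BATCH_THRESHOLD` and `select_mtp_kernel` are hypothetical names, not identifiers from flashinfer/gdn_decode.py.

```python
# Hypothetical names; only the BS < 8 / BS >= 8 split comes from the PR description.
ILP_BATCH_THRESHOLD = 8


def select_mtp_kernel(batch_size: int) -> str:
    """Return a label for the kernel variant the description says is selected."""
    if batch_size < ILP_BATCH_THRESHOLD:
        # Small batches keep the original kernel to avoid ILP overhead.
        return "original"
    # Larger batches use the ILP-rows + SMEM v caching kernel.
    return "optimized_ilp_smem_v"


for bs in (1, 4, 8, 512):
    print(bs, select_mtp_kernel(bs))
```

The real dispatch also takes sequence length into account when choosing the number of ILP rows; that part is left out here because the description does not spell out the rule.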

Performance (Qwen3-Next: q=k=16, v=32, d=128, FP32 state, cache ON, state update ON, 1000 iterations):

Previous Kernel Times (us):

BS T=2 T=3 T=4 T=5 T=6 T=7 T=8
1 5.70 7.10 8.42 9.92 11.07 12.67 14.05
2 6.82 8.26 10.11 11.90 13.39 15.17 16.90
4 10.21 12.70 14.56 22.50 19.10 29.34 32.86
8 16.26 20.41 30.27 35.33 40.83 45.79 51.94
16 28.70 42.88 52.21 61.31 70.75 82.91 90.05
32 61.79 82.08 104.83 122.94 148.58 162.91 184.38
64 105.41 137.47 167.49 201.09 233.28 266.94 301.44
128 192.58 241.41 296.42 352.03 409.98 471.23 532.85
256 372.38 472.18 582.35 703.45 829.63 958.45 1094.30
512 744.24 963.50 1188.12 1439.23 1689.53 1954.30 2231.74

New Kernel Times (us):

BS T=2 T=3 T=4 T=5 T=6 T=7 T=8
1 5.70 7.14 8.42 9.92 11.07 12.70 14.08
2 6.82 8.26 10.14 11.87 13.41 15.20 16.90
4 10.21 12.70 14.56 22.46 19.07 29.34 32.86
8 14.85 19.23 30.24 26.43 31.20 34.24 38.18
16 25.21 32.10 52.16 61.41 70.82 83.01 90.16
32 50.14 63.36 73.47 85.73 99.52 113.76 127.76
64 88.32 116.48 141.36 165.18 190.11 214.72 238.29
128 169.54 217.57 264.86 312.86 360.96 412.00 458.97
256 331.52 425.28 519.10 617.12 715.73 822.96 926.97
512 658.56 846.64 1036.64 1235.76 1443.53 1660.64 1872.56

Speedup vs Previous FlashInfer Kernel:

BS T=2 T=3 T=4 T=5 T=6 T=7 T=8
1 1.00 0.99 1.00 1.00 1.00 1.00 1.00
2 1.00 1.00 1.00 1.00 1.00 1.00 1.00
4 1.00 1.00 1.00 1.00 1.00 1.00 1.00
8 1.09 1.06 1.00 1.34 1.31 1.34 1.36
16 1.14 1.34 1.00 1.00 1.00 1.00 1.00
32 1.23 1.30 1.43 1.43 1.49 1.43 1.44
64 1.19 1.18 1.18 1.22 1.23 1.24 1.27
128 1.14 1.11 1.12 1.13 1.14 1.14 1.16
256 1.12 1.11 1.12 1.14 1.16 1.16 1.18
512 1.13 1.14 1.15 1.16 1.17 1.18 1.19

Small batch sizes (BS < 8): 1.00x (no regression).
Large batch sizes (BS >= 8): 1.00x-1.49x improvement, avg ~1.20x.
All 70 correctness tests pass (10 BS x 7 T values).

Also adds --update-state flag to bench_gdn_decode.py to test with
disable_state_update=False (h output updated after each chunk).

The following results use the MTP setting with the cache enabled but the h update disabled, so that the initial state is not overwritten (cache ON, state update OFF).

Main Branch Kernel Times (µs)

BS \ T 2 3 4 5 6 7 8
1 5.70 7.04 8.29 9.76 10.94 12.53 13.98
2 6.56 8.13 9.86 11.62 13.15 14.91 16.61
4 9.57 11.78 14.05 16.18 18.53 28.72 23.65
8 14.24 17.79 28.16 33.23 38.90 44.06 49.71
16 23.07 38.29 46.70 55.87 64.50 74.48 84.77
32 41.60 67.65 84.58 101.52 121.17 136.34 154.26
64 77.60 113.41 142.64 175.58 205.82 236.59 274.37
128 148.74 206.43 260.35 317.28 371.31 430.32 490.70
256 281.86 403.75 503.42 617.65 738.83 862.19 998.87
512 549.49 809.55 1018.71 1276.85 1525.92 1794.26 2057.52

Optimized Kernel Times (µs)

BS \ T 2 3 4 5 6 7 8
1 5.66 7.04 8.27 9.76 10.94 12.54 13.92
2 6.56 8.19 9.81 11.62 13.18 14.94 16.61
4 9.57 11.84 13.98 16.18 18.62 28.69 23.62
8 14.08 17.60 28.16 24.96 29.34 32.99 37.74
16 22.94 27.97 46.82 44.00 64.48 74.10 85.23
32 40.34 55.78 64.03 76.54 89.34 105.38 119.97
64 68.48 92.61 119.52 141.63 166.59 194.02 218.59
128 129.47 176.83 223.17 269.20 317.36 369.70 418.94
256 250.50 341.54 432.88 524.29 620.11 723.78 823.74
512 492.62 671.73 854.71 1041.25 1245.30 1458.29 1693.33

Speedup (main_time / optimized_time)

Values > 1.0 = optimized kernel is faster

BS \ T 2 3 4 5 6 7 8 Avg
1 1.01 1.00 1.00 1.00 1.00 1.00 1.00 1.00
2 1.00 0.99 1.01 1.00 1.00 1.00 1.00 1.00
4 1.00 0.99 1.01 1.00 1.00 1.00 1.00 1.00
8 1.01 1.01 1.00 1.33 1.33 1.34 1.32 1.19
16 1.01 1.37 1.00 1.27 1.00 1.01 0.99 1.09
32 1.03 1.21 1.32 1.33 1.36 1.29 1.29 1.26
64 1.13 1.22 1.19 1.24 1.24 1.22 1.26 1.21
128 1.15 1.17 1.17 1.18 1.17 1.16 1.17 1.17
256 1.13 1.18 1.16 1.18 1.19 1.19 1.21 1.18
512 1.12 1.21 1.19 1.23 1.23 1.23 1.22 1.20
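
As a worked check of the speedup definition (main_time / optimized_time), the sketch below recomputes the BS=8 row from the two timing tables above. The variable names are illustrative only; the numbers are copied from the tables.

```python
# Times in microseconds for BS=8, T=2..8 (cache ON, state update OFF),
# copied from the "Main Branch" and "Optimized" tables above.
main_us = [14.24, 17.79, 28.16, 33.23, 38.90, 44.06, 49.71]
optimized_us = [14.08, 17.60, 28.16, 24.96, 29.34, 32.99, 37.74]

# Speedup > 1.0 means the optimized kernel is faster.
speedups = [round(m / o, 2) for m, o in zip(main_us, optimized_us)]
row_avg = round(sum(speedups) / len(speedups), 2)

print(speedups)  # [1.01, 1.01, 1.0, 1.33, 1.33, 1.34, 1.32]
print(row_avg)   # 1.19
```

The result reproduces the BS=8 row of the speedup table, including the 1.19 row average.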

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • New public GDN decode APIs with runtime backend selection, flexible state layout/dtype handling, and an MTP option controlling state updates (default behavior changed).
  • Tests
    • Added tests exercising MTP path with FP32 state, cache enabled, and state-update enabled.
  • Benchmarks
    • CLI and benchmarks gain an --update-state flag and report/update-state behavior during MTP runs.

@coderabbitai
Contributor

coderabbitai bot commented Feb 22, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Refactors GDN decode to a backend-driven API: exposes three public decode functions, moves kernel implementations to flashinfer/gdn_kernels, replaces in-file kernel selection with runtime backend flags and run_* entrypoints, adds MTP tile/vec helpers, and propagates a disable_state_update option into tests and benchmarks.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| API & entrypoints: flashinfer/gdn_decode.py | Added public wrappers gated_delta_rule_decode_pretranspose, gated_delta_rule_decode, gated_delta_rule_mtp; replaced in-file kernel configs with runtime backend flags and conditional imports; added run_* delegations and TILE_V constant; support for multiple state layouts and dtype casting; removed legacy compile paths. |
| Kernel implementations: flashinfer/gdn_kernels/..., flashinfer/gdn_kernels/gdn_decode_pretranspose.py, flashinfer/gdn_kernels/gdn_decode_nontranspose.py | New backend modules exposing run_pretranspose_decode, run_nontranspose_decode, run_mtp_decode, plus get_tile_v_mtp/get_vec_size_mtp; large CuTe-based kernels for pretranspose and nontranspose paths, launcher wrappers, and compilation/caching helpers. |
| Kernel exports: flashinfer/gdn_kernels/__init__.py | Optional imports of backend run functions and helpers with try/except guards; __all__ expanded to include new run_* and MTP helpers. |
| Tests: tests/gdn/test_decode_delta_rule.py | Extended _test_verify_kernel_mtp to accept disable_state_update, passed through to gated_delta_rule_mtp; added conditional final-state comparison when state updates are enabled; new test test_mtp_fp32_state_with_cache_and_state_update. |
| Benchmarks & CLI: benchmarks/bench_gdn_decode.py | Threaded disable_state_update into bench logic; added --update-state CLI flag and printed header output reflecting update-state and cache settings; default MTP bench calls now pass state-update flag. |

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host/API
    participant API as gdn_decode API
    participant Kernels as gdn_kernels (module)
    participant Cache as Kernel Cache/Compiler
    participant GPU as Compiled Kernel / Device
    participant State as Output / State Buffers

    Host->>API: call gated_delta_rule_* (inputs, disable_state_update)
    API->>Kernels: select backend (pretranspose/nontranspose/mtp) based on flags
    Kernels->>Cache: request compiled kernel / get_tile_v_mtp/get_vec_size_mtp
    Cache->>Kernels: compiled kernel or compile & cache
    Kernels->>GPU: launch compiled kernel (params, tile/vec, disable_state_update)
    GPU->>State: compute outputs, conditional state writeback
    State->>Host: output tensors and (optionally) updated state returned

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

run-ci, ready

Suggested reviewers

  • yzh119
  • cyx-6
  • bkryu
  • nvmbreughe
  • jimmyzho
  • kahyunnam
  • djmmoss

Poem

🐰 I hop through kernels, tidy tiles in rows,
I stash V in memory where softplus glows,
I switch backends at runtime, choose vec and tile,
I copy states and outputs with a careful smile,
A rabbit cheers — compute done in style!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 77.27% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The PR title accurately describes the main performance optimization: adding ILP rows and SMEM v caching to the MTP kernel, which aligns with the core changes. |
| Description check | ✅ Passed | The PR description is comprehensive, including detailed performance metrics, benchmark tables, test results, and explanation of the dispatch logic. It follows the template structure with description, implementation details, and test coverage. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @ameynaik-hub, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the Gated Delta Rule Multi-Token Processing (MTP) kernel by leveraging advanced GPU programming techniques. The enhancements focus on improving performance across a wide range of batch sizes and sequence lengths through fine-grained control over instruction-level parallelism and strategic use of shared memory for data caching. The changes also include dynamic configuration selection to adapt to varying workload characteristics, resulting in substantial speedups as demonstrated by the provided benchmarks.

Highlights

  • Instruction-Level Parallelism (ILP): Implemented dynamic instruction-level parallelism (ILP) for the GDN MTP kernel, supporting 1, 2, 4, or 8-row processing to improve throughput.
  • Shared Memory (SMEM) Caching for V Tensor: Enabled shared memory caching for V tensor tiles, reducing global memory access latency and improving performance.
  • Dynamic Kernel Dispatch: Introduced dynamic dispatch logic to select optimal kernel configurations (ILP rows, SMEM usage) based on runtime batch size and sequence length, adapting to different workload characteristics.
  • Default State Update Behavior: Changed the default behavior of disable_state_update to False, ensuring the h state is always updated by default in the MTP kernel.
  • New SMEM-Resident Kernel Variant: Added a new SMEM-resident state MTP kernel variant designed for higher occupancy, though it is currently disabled as benchmarks showed it to be slower than the optimized register-resident kernel.
  • Comprehensive MTP Tests: Added extensive test cases covering all batch sizes and sequence lengths with state updates enabled, ensuring the correctness of the optimized MTP kernel.

Changelog
  • flashinfer/gdn_decode.py
    • Updated comment for NUM_THREADS_MTP.
    • Modified get_tile_v_mtp function to refine V-tile selection for large batches and added new helper functions get_use_2row_ilp, get_ilp_rows, and get_use_smem_v for dynamic kernel parameter selection.
    • Reordered variable declarations (r_A_log, r_a, r_dt_bias, r_b) in gdn_decode_kernel_small_batch_pretranspose to enable early loading.
    • Removed conditional kernel selection, now always using run_gdn_decode_kernel_small_batch_pretranspose.
    • Extended gdn_verify_kernel_mtp signature to include ilp_rows and use_smem_v parameters.
    • Allocated sVdata and sOutput in shared memory within gdn_verify_kernel_mtp for V tensor caching and output accumulation.
    • Added multiple register arrays (r_h2 through r_h8) to support 8-row instruction-level parallelism in gdn_verify_kernel_mtp.
    • Implemented distinct code paths within gdn_verify_kernel_mtp for 8-row, 4-row, 2-row, and 1-row ILP, incorporating SMEM v preloading and cooperative output writeback.
    • Updated run_gdn_verify_kernel_mtp to accept and pass the new ilp_rows and use_smem_v parameters.
    • Adjusted shared memory byte calculation in run_gdn_verify_kernel_mtp to account for new shared memory allocations.
    • Introduced a new gdn_verify_kernel_mtp_smem kernel and its corresponding run_gdn_verify_kernel_mtp_smem and _get_compiled_mtp_smem_kernel functions for an SMEM-resident state approach, and changed the default disable_state_update in gated_delta_rule_mtp.
    • Modified gated_delta_rule_mtp to dynamically select between the original and the new SMEM-resident kernels, and to pass ILP and SMEM V parameters to the original kernel, with the SMEM-resident kernel currently disabled.
  • tests/gdn/test_decode_delta_rule.py
    • Added disable_state_update parameter to _test_verify_kernel_mtp for more flexible testing.
    • Modified the test to compare the final state when disable_state_update is False.
    • Introduced a new comprehensive test suite test_mtp_fp32_state_with_cache_and_state_update to validate the optimized MTP kernel across a wide range of batch sizes and sequence lengths with state updates enabled.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page. Here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request significantly improves the performance of the GDN MTP (Multi-Token Processing) kernel by introducing instruction-level parallelism (ILP) with 1/2/4/8-row processing and shared memory caching for V tensor tiles. The dynamic dispatch logic based on batch size and sequence length is well-implemented, and the early reading of gate values in the decode kernel effectively hides memory latency. However, there are some discrepancies between the comments and the code regarding the optimal tile_v for large batches, and several unused functions and kernel variants remain in the codebase. Additionally, the change in the default value of disable_state_update is a breaking change that should be noted.

Comment on lines +132 to +142
     V2: With runtime V-row loop, tile_v=128 is feasible for large batches.
     This halves the grid size (num_v_tiles=1 vs 2), improving L2 cache behavior.
     """
-    if batch_size <= 2:
-        return 4  # Small batch needs max parallelism
-    elif batch_size <= 4:
-        return 8
+    if batch_size <= 4:
+        return 8  # Minimum 8 for 2-row ILP (rows_per_group must be >= 2)
     elif batch_size <= 8:
         return 16
     elif batch_size <= 16:
         return 32
     else:
-        return 64
+        return 64  # BS>=17: 4 groups × 16 rows = 64 rows per tile

medium

There is a discrepancy between the comments and the implementation of get_tile_v_mtp. The comment at line 132 states that tile_v=128 is feasible for large batches and improves L2 cache behavior by halving the grid size. However, the code returns a maximum of 64 for all batch sizes greater than 16. If tile_v=128 was intended for large batches (as also suggested by the comment at line 189), the function should be updated to return 128 for batch_size >= 64.

Suggested change
     V2: With runtime V-row loop, tile_v=128 is feasible for large batches.
     This halves the grid size (num_v_tiles=1 vs 2), improving L2 cache behavior.
     """
-    if batch_size <= 2:
-        return 4  # Small batch needs max parallelism
-    elif batch_size <= 4:
-        return 8
+    if batch_size <= 4:
+        return 8  # Minimum 8 for 2-row ILP (rows_per_group must be >= 2)
     elif batch_size <= 8:
         return 16
     elif batch_size <= 16:
         return 32
     else:
-        return 64
+        return 64  # BS>=17: 4 groups × 16 rows = 64 rows per tile

    V2: With runtime V-row loop, tile_v=128 is feasible for large batches.
    This halves the grid size (num_v_tiles=1 vs 2), improving L2 cache behavior.
    """
    if batch_size <= 4:
        return 8  # Minimum 8 for 2-row ILP (rows_per_group must be >= 2)
    elif batch_size <= 8:
        return 16
    elif batch_size <= 16:
        return 32
    elif batch_size <= 32:
        return 64
    else:
        return 128
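
To make the grid-size argument concrete, here is a small sketch comparing the tile selection in the diff with the suggested one. It assumes a V head dimension of 128 (the d=128 from the PR description) and that the number of V tiles is `v_dim // tile_v`; the function names are hypothetical and the bodies only mirror the branches quoted in this comment.

```python
V_DIM = 128  # assumption: d=128 from the benchmark configuration


def tile_v_in_diff(batch_size: int) -> int:
    """Branches as they appear in the diff under review (max tile is 64)."""
    if batch_size <= 4:
        return 8
    elif batch_size <= 8:
        return 16
    elif batch_size <= 16:
        return 32
    return 64


def tile_v_suggested(batch_size: int) -> int:
    """Branches from the suggested change (128 once batch_size exceeds 32)."""
    if batch_size <= 4:
        return 8
    elif batch_size <= 8:
        return 16
    elif batch_size <= 16:
        return 32
    elif batch_size <= 32:
        return 64
    return 128


def num_v_tiles(tile_v: int, v_dim: int = V_DIM) -> int:
    # Assumed relationship: one tile per tile_v rows of V.
    return v_dim // tile_v


for bs in (16, 32, 64):
    print(bs, num_v_tiles(tile_v_in_diff(bs)), num_v_tiles(tile_v_suggested(bs)))
```

Note that the suggested code switches to 128 for any batch size above 32, whereas the prose of the comment mentions batch_size >= 64; which threshold is intended would need to be confirmed by benchmarks.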

Comment on lines +2215 to +2235
r_h2 = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)
r_h3 = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)
r_h4 = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)
r_h5 = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)
r_h6 = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)
r_h7 = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)
r_h8 = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)

medium

The register tensors r_h2 through r_h8 are unconditionally allocated at the start of the kernel. While the compiler might optimize them away if not used, it is better practice to allocate them conditionally based on the ilp_rows constexpr to ensure minimal register pressure for the 1-row or 2-row paths.

scale: Optional[float] = None,
output: Optional[torch.Tensor] = None,
intermediate_states_buffer: Optional[torch.Tensor] = None,
disable_state_update: bool = False,

medium

Changing the default value of disable_state_update from True to False is a breaking change for the gated_delta_rule_mtp API. Users who relied on the previous default (where the state was not updated during verification) will now have their initial states modified in-place. This should be clearly documented or communicated.
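
The effect of the new default can be illustrated with a toy stand-in. This is not the FlashInfer API: `toy_mtp_step` is a hypothetical function that only mimics the "write the state back in place unless disable_state_update is set" behavior described here.

```python
from typing import List


def toy_mtp_step(
    state: List[float], delta: float, disable_state_update: bool = False
) -> float:
    """Toy stand-in: compute an output and optionally update `state` in place."""
    new_state = [s + delta for s in state]
    output = sum(new_state)
    if not disable_state_update:
        # In-place writeback, which is what the new default enables.
        state[:] = new_state
    return output


state = [1.0, 2.0]

# Opting out explicitly preserves the caller's initial state.
toy_mtp_step(state, 0.5, disable_state_update=True)
print(state)  # [1.0, 2.0]

# Relying on the new default modifies the state in place.
toy_mtp_step(state, 0.5)
print(state)  # [1.5, 2.5]
```

Callers that depended on the old default would therefore need to pass disable_state_update=True explicitly to keep their initial state untouched.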

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/gdn/test_decode_delta_rule.py (1)

614-627: ⚠️ Potential issue | 🔴 Critical

Bug: seed is passed into the disable_state_update parameter position.

After inserting disable_state_update at position 12 in _test_verify_kernel_mtp (line 418), the existing call site still passes seed as the 12th positional arg (line 627). This means:

  • seed (e.g., 0 or 42) silently becomes disable_state_update
  • The actual seed kwarg defaults to 0, ignoring the SEED env var

With SEED=0, disable_state_update receives 0, which is falsy and therefore behaves like False, accidentally enabling state updates in a test that was not designed to verify them.

Fix: pass disable_state_update explicitly as keyword arg
     _test_verify_kernel_mtp(
         dtype,
         batch_size,
         num_q_heads,
         num_k_heads,
         num_v_heads,
         head_size,
         seq_len,
         scale_val,
         alpha,
         beta,
         cache_intermediate_states,
+        seed=seed,
-        seed,
     )
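
The misbinding can be reproduced in isolation. The sketch below uses toy helpers with a shortened, hypothetical signature (the real _test_verify_kernel_mtp takes many more parameters, and the default shown for disable_state_update is an assumption); it shows how a positional seed lands in the newly inserted parameter and how keyword or keyword-only arguments avoid that.

```python
def run_case(cache_intermediate_states, disable_state_update=True, seed=0):
    """Toy helper whose signature gained a parameter in front of `seed`."""
    return {"disable_state_update": disable_state_update, "seed": seed}


# Old-style call site: the second positional value was meant to be the seed.
misbound = run_case(True, 42)
# Keyword call site: each value reaches the parameter it was meant for.
fixed = run_case(True, seed=42)


def run_case_kwonly(cache_intermediate_states, *, disable_state_update=True, seed=0):
    """Keyword-only variant: a stray positional argument raises TypeError."""
    return {"disable_state_update": disable_state_update, "seed": seed}


print(misbound)  # {'disable_state_update': 42, 'seed': 0}
print(fixed)     # {'disable_state_update': True, 'seed': 42}
```

Making the trailing parameters keyword-only is one way to turn this kind of silent misbinding into an immediate error when a signature changes.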
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 614 - 627, The call to
_test_verify_kernel_mtp is passing seed as the 12th positional argument which
now maps to the new parameter disable_state_update; update the call site so
disable_state_update is passed explicitly as a keyword (e.g.,
disable_state_update=False or the intended variable) and ensure seed is passed
via the seed=... keyword so the _test_verify_kernel_mtp(dtype, batch_size, ...,
alpha, beta, cache_intermediate_states, disable_state_update=<value>, seed=seed)
signature receives the correct values.
🧹 Nitpick comments (5)
flashinfer/gdn_decode.py (5)

173-222: get_ilp_rows returns 8 when the guard only checks rows_per_group >= 4 — should verify >= 8.

Line 199 checks if tile_v >= 64 and rows_per_group >= 4: before potentially returning 8 (line 204). While tile_v >= 64 guarantees rows_per_group >= 16 today (since num_groups = 4), the explicit guard of >= 4 is misleadingly weak for an 8-row path. If num_groups ever changes, this could silently allow eighth_rows = 0.

Tighten the guard
-    if tile_v >= 64 and rows_per_group >= 4:
+    if tile_v >= 64 and rows_per_group >= 8:
         # State update ON + low T: 8-row ILP maximizes memory parallelism
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 173 - 222, The guard in get_ilp_rows
currently allows taking the 8-row ILP path when tile_v >= 64 but only checks
rows_per_group >= 4; tighten this to require rows_per_group >= 8 (or
equivalently tile_v >= 32*num_groups) before returning 8 to ensure there are
actually eight rows per group if num_groups changes—update the conditional at
the block that computes tile_v/rows_per_group (and the nested branch that
returns 8 when not disable_state_update and seq_len <= 2) to use rows_per_group
>= 8 (or a stronger tile_v check) so the 8-row path is only taken when enough
rows exist.

2349-2353: Latent correctness risk: 8-row ILP path silently drops rows if rows_per_group is not a multiple of 8.

eighth_rows = rows_per_group // 8 truncates, so if rows_per_group were e.g. 12, only 8 of 12 rows per group would be processed and the remaining 4 would never update state or produce output. The current get_tile_v_mtp values (8, 16, 32, 64) yield rows_per_group ∈ {2, 4, 8, 16}, which are all safe. But if a new tile_v is added (e.g., 48 → rows_per_group=12), this would silently corrupt results.

Consider adding a compile-time assert (or a comment) documenting that rows_per_group must be divisible by ilp_rows.
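
A host-side version of such a check might look like the sketch below. It is plain Python rather than CuTe DSL code, it assumes num_groups = 4 as mentioned in the previous comment, and the helper names are hypothetical.

```python
NUM_GROUPS = 4  # assumption taken from the review comment above


def rows_covered(tile_v: int, ilp_rows: int) -> int:
    """Rows per group actually visited by a loop of rows_per_group // ilp_rows steps."""
    rows_per_group = tile_v // NUM_GROUPS
    return (rows_per_group // ilp_rows) * ilp_rows


def check_ilp_config(tile_v: int, ilp_rows: int) -> None:
    """Reject configurations where truncating division would skip rows."""
    rows_per_group = tile_v // NUM_GROUPS
    if rows_per_group % ilp_rows != 0:
        raise ValueError(
            f"rows_per_group={rows_per_group} is not divisible by ilp_rows={ilp_rows}"
        )


print(rows_covered(64, 8))  # 16: all 16 rows per group are processed
print(rows_covered(48, 8))  # 8: 4 of the 12 rows per group would be skipped

check_ilp_config(64, 8)  # passes silently
```

With the current tile sizes (8, 16, 32, 64) the check never fires for any ILP width the selection logic is described as returning; it only guards against future tile sizes such as 48.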

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2349 - 2353, The 8-row ILP path uses
integer division for eighth_rows = rows_per_group // 8 which will silently drop
leftover rows if rows_per_group is not divisible by ilp_rows; add a compile-time
assertion (or at minimum a clear comment) ensuring rows_per_group % ilp_rows ==
0 before the cutlass.const_expr(ilp_rows == 8) branch so the code fails to
compile if a new tile_v makes rows_per_group non-multiple of 8; reference the
symbols ilp_rows, rows_per_group, eighth_rows and the get_tile_v_mtp/get_tile_v
computation so the check is placed immediately before the for row_oct in
cutlass.range_constexpr(eighth_rows) loop.

2215-2235: Eight register tensors are always declared, even for ilp_rows=1.

r_h2 through r_h8 are allocated regardless of ilp_rows. Since ilp_rows is Constexpr, the compiler should dead-code-eliminate the unused ILP branches and their register usage. However, if the CUTLASS compiler doesn't fully optimize this, it would unnecessarily increase register pressure for low-ILP configs.

Worth confirming the compiler handles this correctly — if not, consider guarding allocations with cutlass.const_expr.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2215 - 2235, r_h2 through r_h8 are
always allocated via cute.make_rmem_tensor which raises register pressure for
low-ILP configs; wrap the allocations in a compile-time conditional using the
Constexpr ilp_rows (e.g. use cutlass.const_expr or equivalent) so you only
create the extra r_hN tensors when ilp_rows > N (allocate up to ilp_rows entries
instead of unconditionally creating r_h2..r_h8), keeping the same layout/Float32
creation calls inside the guarded branches.

1212-1213: Always using the 8-CTA ("small batch") architecture for the pretranspose path is a behavioral change.

The comment says benchmarks show it's better for all batch sizes, but this means gdn_decode_kernel_big_batch_pretranspose and run_gdn_decode_kernel_big_batch_pretranspose are now dead code. Consider removing them or adding a comment noting they're retained for reference.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 1212 - 1213, The code unconditionally
assigns run_func = run_gdn_decode_kernel_small_batch_pretranspose which makes
the big-batch implementations gdn_decode_kernel_big_batch_pretranspose and
run_gdn_decode_kernel_big_batch_pretranspose dead code; either remove those
unused big-batch functions and related helpers or explicitly mark them as
deprecated/kept-for-reference and update the inline comment to explain why the
small-batch 8-CTA path is forced for all batch sizes. Locate the assignment to
run_func and then delete the big-batch functions (and any callers/exports) or
add a clear comment above their definitions and at the assignment site stating
they are retained for historical/reference purposes and will not be used, so
future readers/linters understand they are intentionally unused.

3796-3797: SMEM kernel path is hard-disabled (use_smem_kernel = False) — dead code.

The entire SMEM-resident kernel (gdn_verify_kernel_mtp_smem, run_gdn_verify_kernel_mtp_smem, _get_compiled_mtp_smem_kernel) and the branch at lines 3799-3891 are unreachable. If this is intentionally left for future experimentation, a brief comment explaining the plan would help. Otherwise, this is a significant amount of dead code.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 3796 - 3797, The SMEM kernel path is
hard-disabled by setting use_smem_kernel = False which makes
gdn_verify_kernel_mtp_smem, run_gdn_verify_kernel_mtp_smem,
_get_compiled_mtp_smem_kernel and the branch (currently unreachable around lines
3799-3891) dead code; either remove the unused SMEM functions/branch to
eliminate dead code, or make the path selectable (expose use_smem_kernel as a
runtime/config flag or environment option and wire it into the existing
kernel-selection logic), and if you intend to keep it for future experiments add
a concise TODO comment above use_smem_kernel explaining why it’s disabled and
when/how it should be re-enabled (reference symbols: use_smem_kernel,
gdn_verify_kernel_mtp_smem, run_gdn_verify_kernel_mtp_smem,
_get_compiled_mtp_smem_kernel).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 3680-3681: The docstring for the parameter disable_state_update is
out of sync with the function signature: the signature sets disable_state_update
default to False but the docstring still states "Default: True"; update the
docstring text where disable_state_update is documented (around the parameter
docs near lines ~3716-3718) to state "Default: False" (or otherwise match the
new behavior) so the documented default matches the signature and behavior of
disable_state_update.
- Around line 3643-3663: The cache key tuple is passing (…, ilp_rows,
use_smem_v) but _get_compiled_mtp_kernel currently has a stale use_2row_ilp
parameter before ilp_rows, causing positional misbinds; fix by aligning the
function signature to match the cache key: replace the legacy use_2row_ilp
parameter with use_smem_v: bool (or reorder parameters so ilp_rows and
use_smem_v occupy the last two positions in _get_compiled_mtp_kernel) and keep
ilp_rows: int as the final int parameter; update the default for use_smem_v if
needed and remove or mark use_2row_ilp as deprecated so positional unpacking
from the cache tuple is correct.

---

Outside diff comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 614-627: The call to _test_verify_kernel_mtp is passing seed as
the 12th positional argument which now maps to the new parameter
disable_state_update; update the call site so disable_state_update is passed
explicitly as a keyword (e.g., disable_state_update=False or the intended
variable) and ensure seed is passed via the seed=... keyword so the
_test_verify_kernel_mtp(dtype, batch_size, ..., alpha, beta,
cache_intermediate_states, disable_state_update=<value>, seed=seed) signature
receives the correct values.

---

Nitpick comments:
In `@flashinfer/gdn_decode.py`:
- Around line 173-222: The guard in get_ilp_rows currently allows taking the
8-row ILP path when tile_v >= 64 but only checks rows_per_group >= 4; tighten
this to require rows_per_group >= 8 (or equivalently tile_v >= 32*num_groups)
before returning 8 to ensure there are actually eight rows per group if
num_groups changes—update the conditional at the block that computes
tile_v/rows_per_group (and the nested branch that returns 8 when not
disable_state_update and seq_len <= 2) to use rows_per_group >= 8 (or a stronger
tile_v check) so the 8-row path is only taken when enough rows exist.
- Around line 2349-2353: The 8-row ILP path uses integer division for
eighth_rows = rows_per_group // 8 which will silently drop leftover rows if
rows_per_group is not divisible by ilp_rows; add a compile-time assertion (or at
minimum a clear comment) ensuring rows_per_group % ilp_rows == 0 before the
cutlass.const_expr(ilp_rows == 8) branch so the code fails to compile if a new
tile_v makes rows_per_group non-multiple of 8; reference the symbols ilp_rows,
rows_per_group, eighth_rows and the get_tile_v_mtp/get_tile_v computation so the
check is placed immediately before the for row_oct in
cutlass.range_constexpr(eighth_rows) loop.
- Around line 2215-2235: r_h2 through r_h8 are always allocated via
cute.make_rmem_tensor which raises register pressure for low-ILP configs; wrap
the allocations in a compile-time conditional using the Constexpr ilp_rows (e.g.
use cutlass.const_expr or equivalent) so you only create the extra r_hN tensors
when ilp_rows > N (allocate up to ilp_rows entries instead of unconditionally
creating r_h2..r_h8), keeping the same layout/Float32 creation calls inside the
guarded branches.
- Around line 1212-1213: The code unconditionally assigns run_func =
run_gdn_decode_kernel_small_batch_pretranspose which makes the big-batch
implementations gdn_decode_kernel_big_batch_pretranspose and
run_gdn_decode_kernel_big_batch_pretranspose dead code; either remove those
unused big-batch functions and related helpers or explicitly mark them as
deprecated/kept-for-reference and update the inline comment to explain why the
small-batch 8-CTA path is forced for all batch sizes. Locate the assignment to
run_func and then delete the big-batch functions (and any callers/exports) or
add a clear comment above their definitions and at the assignment site stating
they are retained for historical/reference purposes and will not be used, so
future readers/linters understand they are intentionally unused.
- Around line 3796-3797: The SMEM kernel path is hard-disabled by setting
use_smem_kernel = False which makes gdn_verify_kernel_mtp_smem,
run_gdn_verify_kernel_mtp_smem, _get_compiled_mtp_smem_kernel and the branch
(currently unreachable around lines 3799-3891) dead code; either remove the
unused SMEM functions/branch to eliminate dead code, or make the path selectable
(expose use_smem_kernel as a runtime/config flag or environment option and wire
it into the existing kernel-selection logic), and if you intend to keep it for
future experiments add a concise TODO comment above use_smem_kernel explaining
why it’s disabled and when/how it should be re-enabled (reference symbols:
use_smem_kernel, gdn_verify_kernel_mtp_smem, run_gdn_verify_kernel_mtp_smem,
_get_compiled_mtp_smem_kernel).

Comment on lines +3643 to +3663
@functools.cache
def _get_compiled_mtp_kernel(
    B: int,
    T: int,
    H: int,
    HV: int,
    K: int,
    V: int,
    pool_size: int,
    cache_steps: int,
    disable_state_update: bool,
    cache_intermediate_states: bool,
    scale: float,
    use_qk_l2norm: bool,
    tile_v: int,  # TILE_V - configurable for batch size
    vec_size: int,  # 4 for full warp, 8 for half-warp
    use_2row_ilp: bool = True,  # Whether to use 2-row ILP (legacy, kept for compat)
    ilp_rows: int = 4,  # Number of ILP rows (1, 2, or 4)
):
    """Cache compiled MTP kernel for given configuration."""
    return {}

⚠️ Potential issue | 🟠 Major

Cache key parameters don't match the actual values being passed.

At lines 3896-3913, the cache key tuple ends with (…, ilp_rows, use_smem_v), which is unpacked positionally into _get_compiled_mtp_kernel. But the function signature at line 3643 has use_2row_ilp (a stale legacy param) in position 15 and ilp_rows in position 16, so the local ilp_rows (int) lands on use_2row_ilp (bool), and the local use_smem_v (bool) lands on ilp_rows (int).

This works today because the function body is return {} and @functools.cache just hashes the tuple, but it will silently break if anyone later uses the parameter names inside the body.

Proposed fix: align parameter names with the cache key
 `@functools.cache`
 def _get_compiled_mtp_kernel(
     B: int,
     T: int,
     H: int,
     HV: int,
     K: int,
     V: int,
     pool_size: int,
     cache_steps: int,
     disable_state_update: bool,
     cache_intermediate_states: bool,
     scale: float,
     use_qk_l2norm: bool,
     tile_v: int,  # TILE_V - configurable for batch size
     vec_size: int,  # 4 for full warp, 8 for half-warp
-    use_2row_ilp: bool = True,  # Whether to use 2-row ILP (legacy, kept for compat)
-    ilp_rows: int = 4,  # Number of ILP rows (1, 2, or 4)
+    ilp_rows: int = 4,  # Number of ILP rows (1, 2, 4, or 8)
+    use_smem_v: bool = False,  # Whether to preload v into SMEM
 ):
-    """Cache compiled MTP kernel for given configuration."""
+    """Cache compiled MTP kernel for given configuration (keyed on all args)."""
     return {}
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 3645-3645: Unused function argument: B (ARG001)
[warning] 3646-3646: Unused function argument: T (ARG001)
[warning] 3647-3647: Unused function argument: H (ARG001)
[warning] 3648-3648: Unused function argument: HV (ARG001)
[warning] 3649-3649: Unused function argument: K (ARG001)
[warning] 3650-3650: Unused function argument: V (ARG001)
[warning] 3651-3651: Unused function argument: pool_size (ARG001)
[warning] 3652-3652: Unused function argument: cache_steps (ARG001)
[warning] 3653-3653: Unused function argument: disable_state_update (ARG001)
[warning] 3654-3654: Unused function argument: cache_intermediate_states (ARG001)
[warning] 3655-3655: Unused function argument: scale (ARG001)
[warning] 3656-3656: Unused function argument: use_qk_l2norm (ARG001)
[warning] 3657-3657: Unused function argument: tile_v (ARG001)
[warning] 3658-3658: Unused function argument: vec_size (ARG001)
[warning] 3659-3659: Unused function argument: use_2row_ilp (ARG001)
[warning] 3660-3660: Unused function argument: ilp_rows (ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 3643 - 3663, The cache key tuple is
passing (…, ilp_rows, use_smem_v) but _get_compiled_mtp_kernel currently has a
stale use_2row_ilp parameter before ilp_rows, causing positional misbinds; fix
by aligning the function signature to match the cache key: replace the legacy
use_2row_ilp parameter with use_smem_v: bool (or reorder parameters so ilp_rows
and use_smem_v occupy the last two positions in _get_compiled_mtp_kernel) and
keep ilp_rows: int as the final int parameter; update the default for use_smem_v
if needed and remove or mark use_2row_ilp as deprecated so positional unpacking
from the cache tuple is correct.
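The positional misbinding described in this comment can be reproduced in isolation. The following is a minimal sketch, not the FlashInfer code: `_get_compiled_kernel_stub` is a hypothetical, shortened stand-in for `_get_compiled_mtp_kernel` that keeps only the last four parameters, and the key values are made up for illustration.

```python
import functools
import inspect


@functools.cache
def _get_compiled_kernel_stub(
    tile_v: int,
    vec_size: int,
    use_2row_ilp: bool = True,  # stale legacy parameter (hypothetical stub)
    ilp_rows: int = 4,
):
    """Stand-in cache function: only the argument tuple is hashed."""
    return {}


# The caller builds the key in the new order: (..., ilp_rows, use_smem_v).
ilp_rows, use_smem_v = 8, True
cache_key = (64, 4, ilp_rows, use_smem_v)

# The call succeeds, so nothing flags the mismatch at runtime.
cache = _get_compiled_kernel_stub(*cache_key)

# Binding the same tuple to the signature shows where each value landed.
bound = inspect.signature(_get_compiled_kernel_stub).bind(*cache_key)
misbound = dict(bound.arguments)
print(misbound)
```

Because the body never reads its parameters, the cache still distinguishes configurations correctly; the hazard only appears once someone relies on the parameter names.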

Comment on lines +3680 to +3681
disable_state_update: bool = False,
use_qk_l2norm: bool = True,

⚠️ Potential issue | 🟡 Minor

Docstring disable_state_update default contradicts the signature.

The docstring at line 3717 says Default: ``True``, but the actual default at line 3680 is False. Since this PR intentionally changes the default to False (state is always updated), the docstring should match.

Fix the docstring
         disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``True``.
+            If True, the initial state is not updated. Default: ``False``.

Also applies to: 3716-3718

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 3680 - 3681, The docstring for the
parameter disable_state_update is out of sync with the function signature: the
signature sets disable_state_update default to False but the docstring still
states "Default: True"; update the docstring text where disable_state_update is
documented (around the parameter docs near lines ~3716-3718) to state "Default:
False" (or otherwise match the new behavior) so the documented default matches
the signature and behavior of disable_state_update.

@ameynaik-hub ameynaik-hub force-pushed the ameyn/improve_gdn_mtp_fp32state branch from e748a27 to 577556b Compare February 22, 2026 09:02
@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/gdn/test_decode_delta_rule.py (1)

614-627: ⚠️ Potential issue | 🔴 Critical

seed is positionally assigned to disable_state_update, silently breaking both the seed and the state-update flag.

disable_state_update was inserted at position 11 in _test_verify_kernel_mtp's signature (between cache_intermediate_states and seed), but test_verify_kernel_mtp still passes arguments positionally. The integer seed lands on disable_state_update, and the actual seed parameter of _test_verify_kernel_mtp always falls back to its default of 0.

Concrete effect with the default SEED=0:

  • disable_state_update = 0 = False → state-update is unexpectedly enabled, causing the final-state comparison block (Lines 572–582) to execute in every parametrised run.
  • The SEED environment variable is silently ignored for this test.
🐛 Proposed fix: pass disable_state_update explicitly or use keyword args
     _test_verify_kernel_mtp(
         dtype,
         batch_size,
         num_q_heads,
         num_k_heads,
         num_v_heads,
         head_size,
         seq_len,
         scale_val,
         alpha,
         beta,
         cache_intermediate_states,
-        seed,
+        # disable_state_update keeps its default (True): no state comparison in this test
+        seed=seed,
     )

or, switch the call to keyword arguments to be robust against future signature changes:

     _test_verify_kernel_mtp(
-        dtype,
-        batch_size,
-        num_q_heads,
-        num_k_heads,
-        num_v_heads,
-        head_size,
-        seq_len,
-        scale_val,
-        alpha,
-        beta,
-        cache_intermediate_states,
-        seed,
+        dtype=dtype,
+        batch_size=batch_size,
+        num_q_heads=num_q_heads,
+        num_k_heads=num_k_heads,
+        num_v_heads=num_v_heads,
+        head_size=head_size,
+        seq_len=seq_len,
+        scale=scale_val,
+        alpha=alpha,
+        beta=beta,
+        cache_intermediate_states=cache_intermediate_states,
+        seed=seed,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 614 - 627, The call to
_test_verify_kernel_mtp is passing arguments positionally so the newly added
parameter disable_state_update (inserted before seed) receives the SEED value
and the real seed falls back to default; update the call to
_test_verify_kernel_mtp to pass disable_state_update explicitly (or convert to
keyword arguments for clarity) so seed is passed to the seed parameter and
disable_state_update gets the intended boolean value; locate the invocation of
_test_verify_kernel_mtp in the test (the call shown) and change the argument
list to either include disable_state_update=<value> or use named args for all
parameters.
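The failure mode in this comment (a parameter inserted in the middle of a signature silently shifting positional arguments) can be shown with a small sketch. `_run_case` and `_run_case_kwonly` are hypothetical stand-ins for `_test_verify_kernel_mtp`, and the `True` default for `disable_state_update` is taken from the reviewer's note rather than verified against the test file.

```python
def _run_case(cache_intermediate_states, disable_state_update=True, seed=0):
    """Hypothetical helper: disable_state_update was inserted before seed."""
    return {"disable_state_update": disable_state_update, "seed": seed}


def _run_case_kwonly(cache_intermediate_states, *, disable_state_update=True, seed=0):
    """Same helper, but trailing options are keyword-only (the bare * marker)."""
    return {"disable_state_update": disable_state_update, "seed": seed}


# Stale call site: the seed value binds to disable_state_update.
stale_call = _run_case(True, 1234)

# Keyword call site: seed reaches seed, the flag keeps its default.
fixed_call = _run_case(True, seed=1234)

# With keyword-only options the stale call fails loudly instead of misbinding.
try:
    _run_case_kwonly(True, 1234)
    kwonly_rejects_positional = False
except TypeError:
    kwonly_rejects_positional = True

print(stale_call, fixed_call, kwonly_rejects_positional)
```

Making the optional parameters keyword-only is one possible design choice here; it turns future signature changes into an immediate `TypeError` at every stale call site.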
flashinfer/gdn_decode.py (1)

3797-3801: ⚠️ Potential issue | 🟡 Minor

Non-contiguous intermediate_states_buffer silently discards kernel writes.

If the caller passes a non-contiguous buffer (e.g., a slice or transposed tensor), .contiguous() at Line 3800 creates a detached copy. The kernel writes into that copy, and the original buffer remains unchanged — with no error raised.

🛡️ Proposed fix: assert contiguity up-front
     if cache_intermediate_states:
         buffer_size = intermediate_states_buffer.shape[0]
         cache_steps = intermediate_states_buffer.shape[1]
+        assert intermediate_states_buffer.is_contiguous(), (
+            "intermediate_states_buffer must be contiguous; "
+            "call .contiguous() before passing it in."
+        )
         ...
         intermediate_states = (
             intermediate_states_buffer.to(torch.float32)
-            .reshape(buffer_size * cache_steps * HV, V, K)
-            .contiguous()
+            .reshape(buffer_size * cache_steps * HV, V, K)
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 3797 - 3801, The code currently calls
.contiguous() when creating intermediate_states from intermediate_states_buffer
which can silently create a detached copy and lose kernel writes; before that
conversion assert that intermediate_states_buffer.is_contiguous() (or raise a
clear ValueError) so callers must provide a contiguous tensor. Add the
contiguity check immediately before the intermediate_states = (...) block
referencing intermediate_states_buffer, and include buffer_size, cache_steps,
HV, V, K in the error message or assertion to aid debugging.
🧹 Nitpick comments (3)
flashinfer/gdn_decode.py (3)

3806-3901: Remove or gate the permanently-disabled SMEM kernel dispatch path.

use_smem_kernel = False is hardcoded on Line 3807, making the entire if use_smem_kernel: block (Lines 3809–3901) unreachable dead code. The ~90-line block adds noise and maintenance cost.

Either delete it, or promote the flag to a function parameter / environment variable so callers can opt in experimentally.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 3806 - 3901, The code sets
use_smem_kernel = False making the whole SMEM branch (the if use_smem_kernel:
block that calls _get_compiled_mtp_smem_kernel and
run_gdn_verify_kernel_mtp_smem) dead; either remove that entire branch to
eliminate maintenance noise, or expose the toggle so callers can opt in: add a
parameter (e.g., use_smem_kernel) to the surrounding function signature (with a
default False) or read from an env var (os.environ) and replace the hardcoded
assignment, and then ensure the cache key construction (cache_key_smem),
_get_compiled_mtp_smem_kernel, and compiled invocation remain correct when
enabled; also remove any now-unused symbols/imports if you delete the branch.
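One way the opt-in toggle suggested above could look is sketched below. This is an assumption-laden illustration: the environment variable name `GDN_MTP_USE_SMEM_KERNEL` and the helper `smem_kernel_enabled` are hypothetical and do not exist in FlashInfer; an explicit argument, when given, takes precedence over the environment.

```python
import os

# Hypothetical variable name, chosen only for this sketch.
_ENV_VAR = "GDN_MTP_USE_SMEM_KERNEL"
_TRUTHY = {"1", "true", "yes", "on"}


def smem_kernel_enabled(explicit=None, env=None):
    """Resolve the SMEM-kernel toggle.

    explicit: value passed by the caller (None means "not specified").
    env: mapping to read from; defaults to os.environ.
    """
    if explicit is not None:
        return bool(explicit)
    env = os.environ if env is None else env
    raw = env.get(_ENV_VAR, "0")
    return raw.strip().lower() in _TRUTHY


# Default stays off, matching the current hardcoded behaviour.
print(smem_kernel_enabled(env={}))
print(smem_kernel_enabled(env={_ENV_VAR: "1"}))
```

If such a toggle were added, its resolved value would also need to be part of any kernel cache key so that the two paths do not share compiled artifacts.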

131-133: Misleading V2 comment: describes tile_v=128 but code caps at 64.

The comment claims tile_v=128 "halves the grid size" (num_v_tiles=1) for large batches, but get_tile_v_mtp never returns 128. The comment documents an unimplemented optimisation, which confuses future readers.

✏️ Suggested clarification
-    V2: With runtime V-row loop, tile_v=128 is feasible for large batches.
-    This halves the grid size (num_v_tiles=1 vs 2), improving L2 cache behavior.
+    tile_v=64 for large batches gives num_v_tiles=2 per head.
+    (tile_v=128 would yield num_v_tiles=1 but is not currently used.)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 131 - 133, The comment describing V2
claims tile_v=128 halves the grid size, but the function get_tile_v_mtp (and its
tile_v cap) never returns 128, which is misleading; either update the comment to
accurately state the implemented cap (tile_v ≤ 64) and its effect on
num_v_tiles/L2 behavior or implement the 128-path in get_tile_v_mtp to actually
allow tile_v=128; reference get_tile_v_mtp, tile_v, and num_v_tiles when making
the change so the comment and implementation stay consistent.

2206-2213: sVdata and sOutput are always allocated, inflating SMEM for the use_smem_v=False path.

Both allocations are unconditional, and the smem_bytes calculation in run_gdn_verify_kernel_mtp (Lines 3186–3187) always adds them regardless of use_smem_v. For use_smem_v=False with T=8, tile_v=64 this wastes 3,072 bytes per block (about 26% of the allocated total, or roughly 35% on top of what the non-SMEM path actually needs, taking k_dim=128 in the formula below), reducing CTAs/SM and limiting occupancy for the common non-SMEM path.

♻️ Proposed fix in run_gdn_verify_kernel_mtp
     smem_bytes = (
         4 * T * (k_dim + 8)  # sQ
         + 4 * T * (k_dim + 8)  # sK
         + 4 * T  # sG
         + 4 * T  # sBeta
-        + 4 * T * tile_v  # sVdata (v values for all timesteps)
-        + 2 * T * tile_v  # sOutput (output accumulation in BF16)
         + 128  # alignment
     )
+    if use_smem_v:
+        smem_bytes += 4 * T * tile_v  # sVdata
+        smem_bytes += 2 * T * tile_v  # sOutput

Ideally also wrap the smem.allocate_tensor calls for sVdata/sOutput in if cutlass.const_expr(use_smem_v): if CuTe DSL supports conditional SMEM allocation, to let the compiler eliminate the dead storage for the constexpr-False specialisation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2206 - 2213, sVdata and sOutput are
always allocated which wastes SMEM when use_smem_v is False; modify the kernel
to allocate them only when use_smem_v is true by guarding the
smem.allocate_tensor calls with a compile-time conditional (e.g., if
cutlass.const_expr(use_smem_v):) or equivalent CuTe constexpr guard, and also
update run_gdn_verify_kernel_mtp's smem_bytes accounting to add their sizes only
when use_smem_v is true so the specialisation without SMEM does not reserve dead
storage; target the sVdata and sOutput allocation sites and the smem_bytes logic
in run_gdn_verify_kernel_mtp for the change.
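The SMEM accounting in this comment can be checked with a few lines of arithmetic. The sketch below mirrors the byte formula from the proposed fix; `mtp_smem_bytes` is a hypothetical helper name, and `k_dim=128` is an assumption taken from the d=128 configuration in the PR description.

```python
def mtp_smem_bytes(T: int, k_dim: int, tile_v: int, use_smem_v: bool) -> int:
    """Shared-memory bytes per block, following the formula in the proposed fix."""
    total = (
        4 * T * (k_dim + 8)  # sQ
        + 4 * T * (k_dim + 8)  # sK
        + 4 * T  # sG
        + 4 * T  # sBeta
        + 128  # alignment
    )
    if use_smem_v:
        total += 4 * T * tile_v  # sVdata (FP32 v values for all timesteps)
        total += 2 * T * tile_v  # sOutput (BF16 output accumulation)
    return total


with_v = mtp_smem_bytes(T=8, k_dim=128, tile_v=64, use_smem_v=True)
without_v = mtp_smem_bytes(T=8, k_dim=128, tile_v=64, use_smem_v=False)
saved = with_v - without_v

print(f"with sVdata/sOutput:    {with_v} bytes")
print(f"without sVdata/sOutput: {without_v} bytes")
print(f"saved: {saved} bytes ({100 * saved / with_v:.1f}% of the larger total)")
```

Under these assumptions the two buffers account for 3,072 bytes; whether that changes occupancy in practice depends on the per-SM shared-memory limit and the other resources the kernel uses.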
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 145-175: Remove the dead helper get_use_2row_ilp entirely: it is
never called and conflicts with the canonical get_ilp_rows heuristic, so delete
the get_use_2row_ilp function definition and any related comments/Docstring;
ensure all logic consumers use get_ilp_rows (verify no remaining references to
get_use_2row_ilp) to avoid divergent heuristics like the BS 5–8, T=4
disagreement with get_ilp_rows.

---

Outside diff comments:
In `@flashinfer/gdn_decode.py`:
- Around line 3797-3801: The code currently calls .contiguous() when creating
intermediate_states from intermediate_states_buffer which can silently create a
detached copy and lose kernel writes; before that conversion assert that
intermediate_states_buffer.is_contiguous() (or raise a clear ValueError) so
callers must provide a contiguous tensor. Add the contiguity check immediately
before the intermediate_states = (...) block referencing
intermediate_states_buffer, and include buffer_size, cache_steps, HV, V, K in
the error message or assertion to aid debugging.

In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 614-627: The call to _test_verify_kernel_mtp is passing arguments
positionally so the newly added parameter disable_state_update (inserted before
seed) receives the SEED value and the real seed falls back to default; update
the call to _test_verify_kernel_mtp to pass disable_state_update explicitly (or
convert to keyword arguments for clarity) so seed is passed to the seed
parameter and disable_state_update gets the intended boolean value; locate the
invocation of _test_verify_kernel_mtp in the test (the call shown) and change
the argument list to either include disable_state_update=<value> or use named
args for all parameters.

---

Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 3653-3673: The function _get_compiled_mtp_kernel has a stale
parameter order/name causing the cache key to bind arguments incorrectly; update
its signature to match the cache-key/order by removing the legacy use_2row_ilp
parameter (or rename it) and add the boolean use_smem_v as the final parameter
so the tail is "(..., ilp_rows: int, use_smem_v: bool)"; then update any
references inside _get_compiled_mtp_kernel and its callers to use the new
parameter name and ordering to ensure the cache key receives the correct values.
- Around line 3726-3727: The docstring for the parameter disable_state_update is
incorrect (it says Default: ``True``) while the function signature default is
False; update the parameter description in the docstring to state Default:
``False`` (and ensure any descriptive text about behavior when True/False
remains consistent) for the function/method that declares the
disable_state_update argument so the docs match the signature.

---

Nitpick comments:
In `@flashinfer/gdn_decode.py`:
- Around line 3806-3901: The code sets use_smem_kernel = False making the whole
SMEM branch (the if use_smem_kernel: block that calls
_get_compiled_mtp_smem_kernel and run_gdn_verify_kernel_mtp_smem) dead; either
remove that entire branch to eliminate maintenance noise, or expose the toggle
so callers can opt in: add a parameter (e.g., use_smem_kernel) to the
surrounding function signature (with a default False) or read from an env var
(os.environ) and replace the hardcoded assignment, and then ensure the cache key
construction (cache_key_smem), _get_compiled_mtp_smem_kernel, and compiled
invocation remain correct when enabled; also remove any now-unused
symbols/imports if you delete the branch.
- Around line 131-133: The comment describing V2 claims tile_v=128 halves the
grid size, but the function get_tile_v_mtp (and its tile_v cap) never returns
128, which is misleading; either update the comment to accurately state the
implemented cap (tile_v ≤ 64) and its effect on num_v_tiles/L2 behavior or
implement the 128-path in get_tile_v_mtp to actually allow tile_v=128; reference
get_tile_v_mtp, tile_v, and num_v_tiles when making the change so the comment
and implementation stay consistent.
- Around line 2206-2213: sVdata and sOutput are always allocated which wastes
SMEM when use_smem_v is False; modify the kernel to allocate them only when
use_smem_v is true by guarding the smem.allocate_tensor calls with a
compile-time conditional (e.g., if cutlass.const_expr(use_smem_v):) or
equivalent CuTe constexpr guard, and also update run_gdn_verify_kernel_mtp's
smem_bytes accounting to add their sizes only when use_smem_v is true so the
specialisation without SMEM does not reserve dead storage; target the sVdata and
sOutput allocation sites and the smem_bytes logic in run_gdn_verify_kernel_mtp
for the change.

Comment on lines +145 to +175
def get_use_2row_ilp(batch_size: int, seq_len: int) -> bool:
    """Decide whether to use 2-row ILP based on batch size and sequence length.

    2-row ILP provides ~16% speedup at large batch sizes (BS>=64) across all T values.
    However, at small-to-medium batch sizes (BS=8-16) with T>=4, the extra register
    pressure from maintaining two h-vectors causes register spills and code bloat from
    constexpr unrolling, leading to 1-6% regressions.

    Empirical findings (live benchmarks vs flashinfer baseline):
    - BS <= 2: Disable 2-row ILP (ILP overhead exceeds benefit for tiny batches)
    - BS >= 32 (tile_v=64): Always use 2-row ILP (proven 1.06-1.19x speedup)
    - BS 3-4 (tile_v=8): Use 2-row ILP (rows_per_group=2, minimal overhead)
    - BS 5-8 (tile_v=16): Use 2-row for T!=4; T=4 has compiler-specific regression
    - BS 9-16 (tile_v=32): Use 2-row for T<=3; T>=4 shows regression
    """
    # Very small batch sizes: ILP overhead exceeds benefit
    if batch_size <= 2:
        return False

    tile_v = get_tile_v_mtp(batch_size, seq_len)
    if tile_v >= 64:
        return True  # Large BS (>=32): always beneficial
    if batch_size <= 4:
        return True  # Tiny BS (3-4): minimal overhead
    if tile_v == 16:
        # BS=5-8: T=4 has a specific compiler regression with 2-row ILP
        return seq_len != 4
    if tile_v == 32:
        # BS=9-16: T>=4 shows register pressure regression with 2-row
        return seq_len <= 3
    return True  # Fallback: use 2-row

🛠️ Refactor suggestion | 🟠 Major

Remove the dead get_use_2row_ilp function — it is never called.

get_ilp_rows (Line 178) fully supersedes this function, but both are retained. Because their heuristics disagree on several conditions (e.g., BS 5–8, T=4 returns False here vs 1 from get_ilp_rows), leaving both around is a maintenance hazard.

🗑️ Proposed removal
-def get_use_2row_ilp(batch_size: int, seq_len: int) -> bool:
-    """Decide whether to use 2-row ILP based on batch size and sequence length.
-    ...
-    """
-    # Very small batch sizes: ILP overhead exceeds benefit
-    if batch_size <= 2:
-        return False
-
-    tile_v = get_tile_v_mtp(batch_size, seq_len)
-    if tile_v >= 64:
-        return True  # Large BS (>=32): always beneficial
-    if batch_size <= 4:
-        return True  # Tiny BS (3-4): minimal overhead
-    if tile_v == 16:
-        # BS=5-8: T=4 has a specific compiler regression with 2-row ILP
-        return seq_len != 4
-    if tile_v == 32:
-        # BS=9-16: T>=4 shows register pressure regression with 2-row
-        return seq_len <= 3
-    return True  # Fallback: use 2-row
-
-
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 145 - 175, Remove the dead helper
get_use_2row_ilp entirely: it is never called and conflicts with the canonical
get_ilp_rows heuristic, so delete the get_use_2row_ilp function definition and
any related comments/Docstring; ensure all logic consumers use get_ilp_rows
(verify no remaining references to get_use_2row_ilp) to avoid divergent
heuristics like the BS 5–8, T=4 disagreement with get_ilp_rows.

r_h = cute.make_rmem_tensor(
    cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32
)
r_h2 = cute.make_rmem_tensor(
Collaborator

Can't we create a rmem tensor with shape ((8, vec_size))?

Contributor Author

thanks added the change. let me know if it LGTY.

@vadiklyutiy

What is T in perf results?

@ameynaik-hub
Contributor Author

ameynaik-hub commented Feb 23, 2026

> What is T in perf results?

seqlen

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/bench_gdn_decode.py (1)

2418-2468: ⚠️ Potential issue | 🟠 Major

--use-improved / --use-improved-2 are silently ignored in standalone MTP mode.

run_flashinfer_only_benchmark always routes through bench_gdn_mtp → gated_delta_rule_mtp (original) and never consults args.use_improved or args.use_improved_2. A user running:

python benchmarks/bench_gdn_decode.py --version mtp --use-improved

gets timings for the original kernel without any warning, potentially leading to incorrect performance conclusions. The flags are only honoured when --compare is also passed.

Add routing in the mtp branch of run_flashinfer_only_benchmark, or at minimum emit a warning/error when the improved flags are set but --compare is absent.

🛠️ Minimal guard to surface the issue
 def run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm):
     """Run FlashInfer-only benchmarks."""
+    if getattr(args, "use_improved", False) or getattr(args, "use_improved_2", False):
+        print(
+            "Warning: --use-improved / --use-improved-2 have no effect without --compare. "
+            "Add --compare to benchmark against Triton with the selected implementation."
+        )
     # Determine which versions to benchmark
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2418 - 2468, The MTP branch in
run_flashinfer_only_benchmark always calls bench_gdn_mtp which ends up using the
original gated_delta_rule_mtp, ignoring args.use_improved and
args.use_improved_2; update the mtp branch to (a) consult args.use_improved and
args.use_improved_2 and route to the corresponding improved MTP implementation
(or pass a flag through to bench_gdn_mtp so it chooses the improved
gated_delta_rule_mtp variant), and (b) if either improved flag is set but
args.compare is false, emit a clear warning/error indicating the improved
kernels will be ignored unless --compare is used. Reference
run_flashinfer_only_benchmark, bench_gdn_mtp, and gated_delta_rule_mtp when
making the change.
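A guard of the kind suggested in this comment could be sketched as follows. This is a standalone illustration, not the benchmark script itself: `build_parser` and `ignored_flags` are hypothetical helpers, and only the flag names (`--version`, `--compare`, `--use-improved`, `--use-improved-2`) come from the comment.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Reduced parser with only the flags relevant to the guard."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", default="mtp")
    parser.add_argument("--compare", action="store_true")
    parser.add_argument("--use-improved", action="store_true")
    parser.add_argument("--use-improved-2", action="store_true")
    return parser


def ignored_flags(args: argparse.Namespace) -> list:
    """Return the improved-kernel flags that would be ignored without --compare."""
    if args.compare:
        return []
    candidates = (
        ("--use-improved", args.use_improved),
        ("--use-improved-2", args.use_improved_2),
    )
    return [flag for flag, is_set in candidates if is_set]


args = build_parser().parse_args(["--version", "mtp", "--use-improved"])
for flag in ignored_flags(args):
    print(f"Warning: {flag} has no effect without --compare.")
```

Returning the list of ignored flags, rather than printing inside the helper, keeps the check easy to test and lets the caller decide between a warning and a hard error.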
🧹 Nitpick comments (2)
benchmarks/bench_gdn_decode.py (2)

1589-1595: NameError if bench_mtp_comparison is called directly with an unavailable kernel.

gated_delta_rule_mtp_improved and gated_delta_rule_mtp_improved_2 are unbound when their imports fail. The current sole caller (run_comparison_benchmark) guards with RuntimeError first, but the function itself has no availability check. Adding local guards makes bench_mtp_comparison safe to call independently.

🛡️ Proposed defensive guard
+    if use_improved_2 and not GDN_DECODE_IMPROVED_2_AVAILABLE:
+        raise RuntimeError("gdn_decode_improved_2 is not available.")
+    if use_improved and not GDN_DECODE_IMPROVED_AVAILABLE:
+        raise RuntimeError("gdn_decode_improved is not available.")
+
     if use_improved_2:
         mtp_func = gated_delta_rule_mtp_improved_2
     elif use_improved:
         mtp_func = gated_delta_rule_mtp_improved
     else:
         mtp_func = gated_delta_rule_mtp
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 1589 - 1595, The selection block
in bench_mtp_comparison can raise NameError if gated_delta_rule_mtp_improved or
gated_delta_rule_mtp_improved_2 weren't imported; add defensive availability
checks before assigning mtp_func (e.g., test presence in globals() or try/except
NameError) and raise a clear RuntimeError if the requested improved kernel is
unavailable, referencing the function name bench_mtp_comparison and the kernel
symbols gated_delta_rule_mtp_improved, gated_delta_rule_mtp_improved_2, and
gated_delta_rule_mtp so callers that invoke bench_mtp_comparison directly cannot
trigger an unbound name error.

2556-2565: Ruff TRY003: long RuntimeError messages should be defined in a custom exception class.

Static analysis flags both RuntimeError raises (lines 2557-2560 and 2562-2565) for specifying long messages outside the exception class.

♻️ Proposed fix
-        if args.use_improved and not GDN_DECODE_IMPROVED_AVAILABLE:
-            raise RuntimeError(
-                "Improved GDN decode not available. "
-                "Make sure flashinfer/gdn_decode_improved.py exists."
-            )
-        if args.use_improved_2 and not GDN_DECODE_IMPROVED_2_AVAILABLE:
-            raise RuntimeError(
-                "Improved v2 GDN decode not available. "
-                "Make sure flashinfer/gdn_decode_improved_2.py exists."
-            )
+        if args.use_improved and not GDN_DECODE_IMPROVED_AVAILABLE:
+            raise RuntimeError("gdn_decode_improved not available; ensure flashinfer/gdn_decode_improved.py exists.")
+        if args.use_improved_2 and not GDN_DECODE_IMPROVED_2_AVAILABLE:
+            raise RuntimeError("gdn_decode_improved_2 not available; ensure flashinfer/gdn_decode_improved_2.py exists.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_gdn_decode.py` around lines 2556 - 2565, Replace the two
long RuntimeError raises with custom exception classes that encapsulate the long
messages (e.g., define GdnDecodeNotAvailableError and
GdnDecodeV2NotAvailableError) and raise those instead; check the conditions
using args.use_improved / GDN_DECODE_IMPROVED_AVAILABLE and args.use_improved_2
/ GDN_DECODE_IMPROVED_2_AVAILABLE as before, but move each detailed message into
the corresponding exception class (defined in this module or a shared exceptions
module) and raise the custom exception instance to satisfy Ruff TRY003.
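
A minimal sketch of what the custom-exception approach described above might look like. The class names come from the prompt; the helper `check_improved_kernels` and its parameters are hypothetical and exist only for illustration. Subclassing `RuntimeError` is an assumption made here so that existing `except RuntimeError` handlers would keep working.

```python
class GdnDecodeNotAvailableError(RuntimeError):
    """Raised when the improved GDN decode module could not be imported."""

    def __init__(self) -> None:
        # The long message lives inside the exception class, which is what
        # Ruff TRY003 asks for.
        super().__init__(
            "Improved GDN decode not available. "
            "Make sure flashinfer/gdn_decode_improved.py exists."
        )


class GdnDecodeV2NotAvailableError(RuntimeError):
    """Raised when the improved v2 GDN decode module could not be imported."""

    def __init__(self) -> None:
        super().__init__(
            "Improved v2 GDN decode not available. "
            "Make sure flashinfer/gdn_decode_improved_2.py exists."
        )


def check_improved_kernels(
    use_improved: bool,
    use_improved_2: bool,
    improved_available: bool,
    improved_2_available: bool,
) -> None:
    """Hypothetical helper: fail early if a requested kernel is missing."""
    if use_improved and not improved_available:
        raise GdnDecodeNotAvailableError
    if use_improved_2 and not improved_2_available:
        raise GdnDecodeV2NotAvailableError
```

The call sites then carry no message text at all, so the raise statements stay short regardless of how detailed the message is.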

Comment on lines +636 to +670
@pytest.mark.parametrize("seq_len", [2, 3, 4, 5, 6, 7, 8])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512])
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mtp_fp32_state_with_cache_and_state_update(
dtype: str,
batch_size: int,
seq_len: int,
seed: int = int(os.environ.get("SEED", "0")),
):
"""
Comprehensive MTP test with FP32 state, intermediate caching ON, state update ON.

This tests the production configuration:
- FP32 h state (not bf16)
- cache_intermediate_states=True
- disable_state_update=False (h is updated)
- All batch sizes: 1, 2, 4, 8, 16, 32, 64, 128, 256, 512
- All sequence lengths: 2, 3, 4, 5, 6, 7, 8
"""
scale_val = 1.0 / math.sqrt(128) # head_size=128
_test_verify_kernel_mtp(
dtype=dtype,
batch_size=batch_size,
num_q_heads=16,
num_k_heads=16,
num_v_heads=32,
head_size=128,
seq_len=seq_len,
scale=scale_val,
alpha=True,
beta=True,
cache_intermediate_states=True,
disable_state_update=False, # State update ON
seed=seed,
)

⚠️ Potential issue | 🟡 Minor

New comprehensive test skipped by module-level pytestmark.

Line 26 applies pytest.mark.skip to the entire module, so test_mtp_fp32_state_with_cache_and_state_update (70 parametrized cases) will not execute in CI until that skip is removed. The test itself is logically sound, but the coverage claimed by the PR won't materialize while the module-level skip remains.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 636 - 670, The module-level
pytest mark named pytestmark is applying pytest.mark.skip which prevents the
test function test_mtp_fp32_state_with_cache_and_state_update (and its
parametrizations) from running; remove or change the module-level skip so this
test executes in CI—either delete or scope the skip to only specific tests
(replace with conditional skip using pytest.mark.skipif with an environment/flag
check) so that test_mtp_fp32_state_with_cache_and_state_update and its
parametrized cases run during normal CI.

@ameynaik-hub ameynaik-hub force-pushed the ameyn/improve_gdn_mtp_fp32state branch from 3ccb7b3 to 62ac6cb Compare February 23, 2026 04:59

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/gdn/test_decode_delta_rule.py (1)

614-627: ⚠️ Potential issue | 🔴 Critical

seed is silently passed to disable_state_update due to the new positional parameter insertion.

After the new disable_state_update: bool = True parameter was inserted at position 12 (between cache_intermediate_states and seed) in _test_verify_kernel_mtp, the existing call site still passes 12 positional arguments ending in seed. The mapping is now:

| Call arg | Parameter received |
|---|---|
| cache_intermediate_states (pos 11) | cache_intermediate_states |
| seed (pos 12) | disable_state_update |
| (nothing) | seed → falls back to default 0 |

Consequences:

  • Default SEED=0: disable_state_update receives 0 (falsy), so `not disable_state_update` evaluates to True and an unintended final-state comparison is triggered; the kernel is invoked with disable_state_update=False, so the state IS written back. The test may pass or silently produce a wrong verdict.
  • Non-zero SEED: disable_state_update receives the seed value (truthy), so the final-state comparison is skipped; and since the seed parameter falls back to its default of 0, the requested random seed is silently ignored.
🐛 Proposed fix — use a keyword argument for `seed`
     _test_verify_kernel_mtp(
         dtype,
         batch_size,
         num_q_heads,
         num_k_heads,
         num_v_heads,
         head_size,
         seq_len,
         scale_val,
         alpha,
         beta,
         cache_intermediate_states,
-        seed,
+        seed=seed,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 614 - 627, The call to
_test_verify_kernel_mtp is passing positional args so the inserted parameter
disable_state_update now receives the test's seed value; update the call to
explicitly pass the seed as a keyword (seed=seed) or explicitly pass
disable_state_update (e.g. disable_state_update=True/False) and then seed as the
named argument so that _test_verify_kernel_mtp(seed=...) receives the intended
seed and disable_state_update is not shadowed.
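
The argument shift can be reproduced in isolation. The sketch below uses simplified, hypothetical stand-ins for `_test_verify_kernel_mtp` (only the last few parameters are modelled) and also shows a keyword-only variant, which is an additional option not proposed in the review, that would turn the silent shift into an immediate `TypeError`.

```python
def verify_new(cache_intermediate_states, disable_state_update=True, seed=0):
    """Stand-in for the signature after disable_state_update was inserted."""
    return {"disable_state_update": disable_state_update, "seed": seed}


def verify_kwonly(cache_intermediate_states, *, disable_state_update=True, seed=0):
    """Variant where the trailing options are keyword-only."""
    return {"disable_state_update": disable_state_update, "seed": seed}


# The unchanged call site still passes the seed positionally, so it lands
# in disable_state_update and the real seed falls back to its default.
shifted = verify_new(True, 42)
print(shifted)  # {'disable_state_update': 42, 'seed': 0}

# Passing the seed by keyword restores the intended binding.
fixed = verify_new(True, seed=42)
print(fixed)  # {'disable_state_update': True, 'seed': 42}

# With keyword-only parameters the stale call fails loudly instead.
try:
    verify_kwonly(True, 42)
except TypeError as exc:
    kwonly_error = exc
    print("TypeError:", exc)
```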
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_gdn_decode.py`:
- Line 1580: bench_mtp_comparison currently hardcodes disable_state_update=False
so the CLI --update-state flag is ignored when running with --compare; update
bench_mtp_comparison to accept a disable_state_update (or update_state)
parameter and propagate the flag from run_comparison_benchmark (where CLI flags
are parsed), and also adjust the comparison header generation (the code that
prints the MTP comparison header) to include the update_state status similar to
the non-compare MTP header; ensure run_comparison_benchmark forwards the parsed
flag into bench_mtp_comparison and use that value when building the comparison
header/display.

---

Outside diff comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 614-627: The call to _test_verify_kernel_mtp is passing positional
args so the inserted parameter disable_state_update now receives the test's seed
value; update the call to explicitly pass the seed as a keyword (seed=seed) or
explicitly pass disable_state_update (e.g. disable_state_update=True/False) and
then seed as the named argument so that _test_verify_kernel_mtp(seed=...)
receives the intended seed and disable_state_update is not shadowed.

---

Duplicate comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 636-670: The module-level pytestmark (pytestmark) is skipping the
entire test file so the new test test_mtp_fp32_state_with_cache_and_state_update
(and others) never run; remove or narrow the pytestmark skip so it no longer
applies globally—either delete or change the module-level pytestmark to a
conditional/targeted skip, or apply skip markers only to specific flaky tests
rather than the whole module, ensuring
test_mtp_fp32_state_with_cache_and_state_update is executed in CI.

@ameynaik-hub
Contributor Author

ameynaik-hub commented Feb 24, 2026

Results with intermediate caching enabled but h update disabled. With MTP we want caching enabled but don't want to overwrite the initial state, since it is required in case the draft doesn't speculate correctly.
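
The rationale can be illustrated with a toy model. This is only a sketch: the scalar recurrence `h = decay * h + x` and the names `run_draft_tokens` and `commit_state` are hypothetical stand-ins, not the GDN update rule or any FlashInfer API. It shows why the per-token intermediate states are cached while the initial state is left untouched until verification decides how many draft tokens are accepted.

```python
def run_draft_tokens(initial_state: float, draft_tokens: list, decay: float = 0.5) -> list:
    """Return the state after each draft token without modifying initial_state."""
    intermediates = []
    h = initial_state
    for x in draft_tokens:
        h = decay * h + x  # toy recurrence, stands in for the real state update
        intermediates.append(h)
    return intermediates


def commit_state(initial_state: float, intermediates: list, num_accepted: int) -> float:
    """Pick the state to continue from once verification has finished."""
    if num_accepted == 0:
        # Nothing was accepted: the untouched initial state is still needed.
        return initial_state
    return intermediates[num_accepted - 1]


h0 = 1.0
cached = run_draft_tokens(h0, [1.0, 2.0, 3.0])
print(cached)                       # [1.5, 2.75, 4.375]
print(commit_state(h0, cached, 2))  # 2.75
print(commit_state(h0, cached, 0))  # 1.0
```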

Main Branch Kernel Times (µs)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 5.70 | 7.04 | 8.29 | 9.76 | 10.94 | 12.53 | 13.98 |
| 2 | 6.56 | 8.13 | 9.86 | 11.62 | 13.15 | 14.91 | 16.61 |
| 4 | 9.57 | 11.78 | 14.05 | 16.18 | 18.53 | 28.72 | 23.65 |
| 8 | 14.24 | 17.79 | 28.16 | 33.23 | 38.90 | 44.06 | 49.71 |
| 16 | 23.07 | 38.29 | 46.70 | 55.87 | 64.50 | 74.48 | 84.77 |
| 32 | 41.60 | 67.65 | 84.58 | 101.52 | 121.17 | 136.34 | 154.26 |
| 64 | 77.60 | 113.41 | 142.64 | 175.58 | 205.82 | 236.59 | 274.37 |
| 128 | 148.74 | 206.43 | 260.35 | 317.28 | 371.31 | 430.32 | 490.70 |
| 256 | 281.86 | 403.75 | 503.42 | 617.65 | 738.83 | 862.19 | 998.87 |
| 512 | 549.49 | 809.55 | 1018.71 | 1276.85 | 1525.92 | 1794.26 | 2057.52 |

Optimized Kernel Times (µs)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|
| 1 | 5.66 | 7.04 | 8.27 | 9.76 | 10.94 | 12.54 | 13.92 |
| 2 | 6.56 | 8.19 | 9.81 | 11.62 | 13.18 | 14.94 | 16.61 |
| 4 | 9.57 | 11.84 | 13.98 | 16.18 | 18.62 | 28.69 | 23.62 |
| 8 | 14.08 | 17.60 | 28.16 | 24.96 | 29.34 | 32.99 | 37.74 |
| 16 | 22.94 | 27.97 | 46.82 | 44.00 | 64.48 | 74.10 | 85.23 |
| 32 | 40.34 | 55.78 | 64.03 | 76.54 | 89.34 | 105.38 | 119.97 |
| 64 | 68.48 | 92.61 | 119.52 | 141.63 | 166.59 | 194.02 | 218.59 |
| 128 | 129.47 | 176.83 | 223.17 | 269.20 | 317.36 | 369.70 | 418.94 |
| 256 | 250.50 | 341.54 | 432.88 | 524.29 | 620.11 | 723.78 | 823.74 |
| 512 | 492.62 | 671.73 | 854.71 | 1041.25 | 1245.30 | 1458.29 | 1693.33 |

Speedup (main_time / optimized_time)

Values > 1.0 = optimized kernel is faster

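| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Avg |
|---|---|---|---|---|---|---|---|---|
| 1 | 1.01 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| 2 | 1.00 | 0.99 | 1.01 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| 4 | 1.00 | 0.99 | 1.01 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| 8 | 1.01 | 1.01 | 1.00 | 1.33 | 1.33 | 1.34 | 1.32 | 1.19 |
| 16 | 1.01 | 1.37 | 1.00 | 1.27 | 1.00 | 1.01 | 0.99 | 1.09 |
| 32 | 1.03 | 1.21 | 1.32 | 1.33 | 1.36 | 1.29 | 1.29 | 1.26 |
| 64 | 1.13 | 1.22 | 1.19 | 1.24 | 1.24 | 1.22 | 1.26 | 1.21 |
| 128 | 1.15 | 1.17 | 1.17 | 1.18 | 1.17 | 1.16 | 1.17 | 1.17 |
| 256 | 1.13 | 1.18 | 1.16 | 1.18 | 1.19 | 1.19 | 1.21 | 1.18 |
| 512 | 1.12 | 1.21 | 1.19 | 1.23 | 1.23 | 1.23 | 1.22 | 1.20 |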
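
For reference, the speedup figures can be recomputed from the raw timings. The short sketch below does this for the BS=8 row using the main-branch and optimized times quoted above; the variable names are illustrative, and the average is taken over the unrounded ratios, which is an assumption about how the Avg column was produced.

```python
# BS=8 timings in microseconds for T=2..8, copied from the two tables above.
main_us = [14.24, 17.79, 28.16, 33.23, 38.90, 44.06, 49.71]
optimized_us = [14.08, 17.60, 28.16, 24.96, 29.34, 32.99, 37.74]

# Speedup = main_time / optimized_time; values > 1.0 mean the optimized kernel is faster.
speedups = [m / o for m, o in zip(main_us, optimized_us)]
rounded = [round(s, 2) for s in speedups]
average = round(sum(speedups) / len(speedups), 2)

print(rounded)  # [1.01, 1.01, 1.0, 1.33, 1.33, 1.34, 1.32]
print(average)  # 1.19
```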

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@ameynaik-hub is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@ameynaik-hub
Contributor Author

GDN MTP Kernel Benchmark Results

GPU: NVIDIA B200
Model: Qwen 3.5
Dimensions: q_heads=16, k_heads=16, v_heads=64, head_dim=128
Data type: bfloat16 (FP32 h-state)
Settings:

  • T=1: Single token decode (pretranspose layout)
  • T≥2: MTP with intermediate states cached, state update disabled

Benchmark Command

# T=1 Decode
python benchmarks/bench_gdn_decode.py \
    --num-q-heads 16 --num-k-heads 16 --num-v-heads 64 --head-size 128 \
    --version mtp --compare \
    --batch-size 1 2 4 8 16 32 64 128 256 \
    --seq-len 1 \
    --iters 100 --warmup 10

# T=2..8 MTP (intermediate caching ON, state update OFF)
python benchmarks/bench_gdn_decode.py \
    --num-q-heads 16 --num-k-heads 16 --num-v-heads 64 --head-size 128 \
    --version mtp --compare \
    --batch-size 1 2 4 8 16 32 64 128 256 \
    --seq-len 2 3 4 5 6 7 8 \
    --cache-intermediate-states \
    --iters 100 --warmup 10

Raw Kernel Times (µs) - T (rows) vs BS (columns)

Triton
┌──────┬───────┬───────┬───────┬────────┬────────┬────────┬────────┬─────────┬─────────┐
│ T\BS │   1   │   2   │   4   │   8    │   16   │   32   │   64   │   128   │   256   │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 1    │ 5.57  │ 7.74  │ 12.54 │ 19.94  │ 36.45  │ 70.75  │ 134.78 │ 275.97  │ 543.36  │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 2    │ 7.87  │ 13.41 │ 20.90 │ 35.18  │ 66.58  │ 137.82 │ 272.79 │ 538.05  │ 1079.60 │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 3    │ 10.05 │ 17.41 │ 26.94 │ 45.57  │ 90.02  │ 182.77 │ 357.54 │ 704.16  │ 1400.87 │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 4    │ 12.19 │ 21.23 │ 32.58 │ 55.94  │ 114.85 │ 227.42 │ 445.54 │ 887.56  │ 1750.12 │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 5    │ 14.35 │ 25.06 │ 38.66 │ 66.86  │ 140.02 │ 273.26 │ 535.86 │ 1059.24 │ 2117.32 │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 6    │ 16.54 │ 28.94 │ 44.43 │ 77.70  │ 165.12 │ 318.87 │ 624.58 │ 1240.88 │ 2478.71 │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 7    │ 18.67 │ 32.82 │ 51.22 │ 89.25  │ 189.49 │ 364.90 │ 716.27 │ 1426.53 │ 2830.11 │
├──────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┼─────────┤
│ 8    │ 20.96 │ 36.80 │ 59.31 │ 101.52 │ 214.22 │ 410.82 │ 811.11 │ 1594.97 │ 3198.43 │
└──────┴───────┴───────┴───────┴────────┴────────┴────────┴────────┴─────────┴─────────┘
Main Branch (FlashInfer)
┌──────┬───────┬───────┬───────┬───────┬────────┬────────┬────────┬─────────┬─────────┐
│ T\BS │   1   │   2   │   4   │   8   │   16   │   32   │   64   │   128   │   256   │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 1    │ 5.65  │ 8.48  │ 13.18 │ 21.31 │ 37.70  │ 71.57  │ 137.34 │ 271.36  │ 531.87  │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 2    │ 7.36  │ 11.78 │ 17.28 │ 28.35 │ 51.84  │ 105.70 │ 191.87 │ 373.36  │ 730.24  │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 3    │ 8.96  │ 15.01 │ 22.03 │ 36.67 │ 72.06  │ 138.26 │ 242.00 │ 470.26  │ 932.96  │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 4    │ 10.75 │ 17.73 │ 26.53 │ 49.84 │ 89.27  │ 168.93 │ 296.77 │ 576.96  │ 1145.73 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 5    │ 12.45 │ 20.86 │ 36.80 │ 59.55 │ 106.13 │ 201.38 │ 354.21 │ 689.76  │ 1383.81 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 6    │ 13.86 │ 24.03 │ 35.58 │ 69.09 │ 126.37 │ 234.14 │ 411.38 │ 805.09  │ 1626.69 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 7    │ 15.55 │ 27.87 │ 48.16 │ 78.77 │ 142.69 │ 268.88 │ 470.71 │ 927.28  │ 1890.87 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼─────────┼─────────┤
│ 8    │ 17.70 │ 31.01 │ 53.87 │ 88.80 │ 161.44 │ 302.70 │ 530.02 │ 1059.38 │ 2149.64 │
└──────┴───────┴───────┴───────┴───────┴────────┴────────┴────────┴─────────┴─────────┘
Current Branch (FlashInfer)
┌──────┬───────┬───────┬───────┬───────┬────────┬────────┬────────┬────────┬─────────┐
│ T\BS │   1   │   2   │   4   │   8   │   16   │   32   │   64   │  128   │   256   │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 1    │ 5.63  │ 8.51  │ 13.15 │ 19.57 │ 34.62  │ 65.63  │ 126.48 │ 246.45 │ 487.15  │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 2    │ 7.39  │ 11.84 │ 17.28 │ 25.12 │ 45.47  │ 88.75  │ 169.60 │ 330.99 │ 653.97  │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 3    │ 9.04  │ 14.94 │ 22.06 │ 32.91 │ 57.63  │ 116.66 │ 217.34 │ 423.41 │ 836.58  │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 4    │ 10.75 │ 17.79 │ 26.51 │ 49.57 │ 88.93  │ 141.07 │ 265.01 │ 516.26 │ 1021.01 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 5    │ 12.45 │ 20.83 │ 36.83 │ 49.02 │ 106.30 │ 164.78 │ 312.64 │ 609.20 │ 1207.91 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 6    │ 13.95 │ 24.10 │ 35.47 │ 58.02 │ 127.36 │ 190.26 │ 361.87 │ 706.15 │ 1400.90 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 7    │ 15.52 │ 28.10 │ 48.19 │ 63.97 │ 144.27 │ 214.38 │ 411.82 │ 805.65 │ 1602.34 │
├──────┼───────┼───────┼───────┼───────┼────────┼────────┼────────┼────────┼─────────┤
│ 8    │ 17.66 │ 31.10 │ 53.66 │ 72.06 │ 160.53 │ 240.75 │ 460.13 │ 903.76 │ 1801.93 │
└──────┴───────┴───────┴───────┴───────┴────────┴────────┴────────┴────────┴─────────┘
---
Speedup: Current Branch vs Triton

Values > 1.0 mean Current Branch is faster
┌──────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ T\BS │   1   │   2   │   4   │   8   │  16   │  32   │  64   │  128  │  256  │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 1    │ 0.99x │ 0.91x │ 0.95x │ 1.02x │ 1.05x │ 1.08x │ 1.07x │ 1.12x │ 1.12x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 2    │ 1.06x │ 1.13x │ 1.21x │ 1.40x │ 1.46x │ 1.55x │ 1.61x │ 1.63x │ 1.65x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 3    │ 1.11x │ 1.16x │ 1.22x │ 1.38x │ 1.56x │ 1.57x │ 1.65x │ 1.66x │ 1.67x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 4    │ 1.13x │ 1.19x │ 1.23x │ 1.13x │ 1.29x │ 1.61x │ 1.68x │ 1.72x │ 1.71x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 5    │ 1.15x │ 1.20x │ 1.05x │ 1.36x │ 1.32x │ 1.66x │ 1.71x │ 1.74x │ 1.75x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 6    │ 1.19x │ 1.20x │ 1.25x │ 1.34x │ 1.30x │ 1.68x │ 1.73x │ 1.76x │ 1.77x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 7    │ 1.20x │ 1.17x │ 1.06x │ 1.40x │ 1.31x │ 1.70x │ 1.74x │ 1.77x │ 1.77x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 8    │ 1.19x │ 1.18x │ 1.11x │ 1.41x │ 1.33x │ 1.71x │ 1.76x │ 1.76x │ 1.78x │
└──────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

---
Speedup: Current Branch vs Main Branch

Values > 1.0 mean Current Branch is faster
┌──────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ T\BS │   1   │   2   │   4   │   8   │  16   │  32   │  64   │  128  │  256  │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 1    │ 1.00x │ 1.00x │ 1.00x │ 1.09x │ 1.09x │ 1.09x │ 1.09x │ 1.10x │ 1.09x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 2    │ 1.00x │ 0.99x │ 1.00x │ 1.13x │ 1.14x │ 1.19x │ 1.13x │ 1.13x │ 1.12x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 3    │ 0.99x │ 1.00x │ 1.00x │ 1.11x │ 1.25x │ 1.18x │ 1.11x │ 1.11x │ 1.12x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 4    │ 1.00x │ 1.00x │ 1.00x │ 1.01x │ 1.00x │ 1.20x │ 1.12x │ 1.12x │ 1.12x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 5    │ 1.00x │ 1.00x │ 1.00x │ 1.21x │ 1.00x │ 1.22x │ 1.13x │ 1.13x │ 1.15x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 6    │ 0.99x │ 1.00x │ 1.00x │ 1.19x │ 0.99x │ 1.23x │ 1.14x │ 1.14x │ 1.16x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 7    │ 1.00x │ 0.99x │ 1.00x │ 1.23x │ 0.99x │ 1.25x │ 1.14x │ 1.15x │ 1.18x │
├──────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┼───────┤
│ 8    │ 1.00x │ 1.00x │ 1.00x │ 1.23x │ 1.01x │ 1.26x │ 1.15x │ 1.17x │ 1.19x │
└──────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

@yongwww
Member

yongwww commented Feb 26, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !349 has been created, and the CI pipeline #44898535 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44898535: 10/20 passed



@flashinfer_api
def gated_delta_rule_mtp(

This file is getting really long. Is it possible to separate the APIs out from the kernel implementation and only keep the APIs in this file?


@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gdn_decode.py (1)

486-493: ⚠️ Potential issue | 🟠 Major

Missing availability check for MTP helpers.

If _MTP_AVAILABLE is False, both get_tile_v_mtp and get_vec_size_mtp are None. Calling them at lines 492–493 will raise an unhelpful TypeError: 'NoneType' object is not callable.

Add an early availability check to provide a clear error message:

 def gated_delta_rule_mtp(
     ...
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     ...
     # Validate input shapes
     B, T, H, K = q.shape
     _, _, HV, V = v.shape
     pool_size = initial_state.shape[0]

+    if not _MTP_AVAILABLE:
+        raise ImportError(
+            "MTP kernel backend not available. "
+            "Please ensure flashinfer is properly installed with CUDA support."
+        )
+
     # Dynamic TILE_V and vec_size selection based on batch size and sequence length
     tile_v = get_tile_v_mtp(B, T)
     vec_size = get_vec_size_mtp(B, T)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 486 - 493, The code calls
get_tile_v_mtp and get_vec_size_mtp without checking _MTP_AVAILABLE, which will
raise a TypeError if those helpers are None; add an early availability check
before calling them (e.g., at the start of the validation block around B, T, H,
K) that raises a clear RuntimeError or ValueError when _MTP_AVAILABLE is False,
mentioning _MTP_AVAILABLE and the two helper names, and only call get_tile_v_mtp
and get_vec_size_mtp when _MTP_AVAILABLE is True so tile_v and vec_size are
computed safely.
🧹 Nitpick comments (7)
flashinfer/gdn_kernels/gdn_decode_pretranspose.py (4)

799-813: Cache function returns empty dict; all parameters serve only as cache keys.

The function body is return {} and all parameters are intentionally unused (they form the cache key via @functools.cache). This pattern is valid but can be confusing. Consider adding a brief docstring clarifying that parameters are cache-key-only.

 `@functools.cache`
 def _get_compiled_decode_kernel(
     B: int,
     ...
     use_qk_l2norm: bool,
 ):
-    """Cache compiled kernel for given configuration (pretranspose version)."""
+    """Return a cache dict keyed by the configuration tuple.
+    
+    All parameters form the cache key via `@functools.cache`; the returned
+    dict is populated by run_pretranspose_decode on first call.
+    """
     # This will be populated on first call
     return {}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py` around lines 799 - 813,
The function _get_compiled_decode_kernel currently returns an empty dict and
uses its parameters only to form the functools.cache key, which is confusing;
update the function docstring to explicitly state that B, T, H, HV, K, V, dtype,
scale, and use_qk_l2norm are only used as cache keys and that the function
intentionally returns an empty placeholder (to be populated later on first call)
so readers understand the pattern and won't expect parameter usage in the body.
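
The pattern is easier to follow in a stripped-down form. In this sketch, `get_kernel_cache` and `run_kernel` are hypothetical stand-ins for `_get_compiled_decode_kernel` and its caller, and the "compiled kernel" is just a string. The point is that `functools.cache` hands back the same mutable dict for the same argument tuple, so the caller can populate it once and reuse it afterwards.

```python
import functools


@functools.cache
def get_kernel_cache(batch_size: int, seq_len: int, dtype: str) -> dict:
    """Return a per-configuration dict; the parameters are cache keys only."""
    return {}


def run_kernel(batch_size: int, seq_len: int, dtype: str) -> str:
    cache = get_kernel_cache(batch_size, seq_len, dtype)
    if "compiled" not in cache:
        # Stand-in for an expensive compilation step, done once per configuration.
        cache["compiled"] = f"kernel<{batch_size},{seq_len},{dtype}>"
        cache["compile_count"] = cache.get("compile_count", 0) + 1
    return cache["compiled"]


first = run_kernel(8, 4, "bfloat16")
second = run_kernel(8, 4, "bfloat16")   # reuses the populated dict
other = run_kernel(16, 4, "bfloat16")   # different key, separate dict

print(first == second)                                        # True
print(get_kernel_cache(8, 4, "bfloat16")["compile_count"])    # 1
```

Note that `functools.cache` requires Python 3.9 or newer; `functools.lru_cache(maxsize=None)` behaves the same way on older versions.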

741-742: Dead code: expression result discarded.

Same issue as in run_gdn_decode_kernel_small_batch_pretranspose — this expression computes but never uses the result.

     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
-    v_dim * k_dim * batch_size * 4 / 1024 / 1024
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py` around lines 741 - 742, In
the function that computes num_v_tiles (same pattern as
run_gdn_decode_kernel_small_batch_pretranspose), remove the dead standalone
expression "v_dim * k_dim * batch_size * 4 / 1024 / 1024" or convert it to a
meaningful use (e.g., assign to a clearly named variable like estimated_mb or
log it) so the computed value is not discarded; update the code around
num_v_tiles to either delete the stray expression or replace it with an
assignment or logging call to preserve the intent.

638-639: Dead code: expression result discarded.

Line 639 computes a value but doesn't store or use it. This appears to be a debugging artifact.

     num_v_tiles = cute.ceil_div(v_dim, TILE_V)
-    v_dim * k_dim * batch_size * 4 / 1024 / 1024
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py` around lines 638 - 639,
Remove the dead expression that computes v_dim * k_dim * batch_size * 4 / 1024 /
1024 (likely a leftover debug calculation) in gdn_decode_pretranspose.py; either
delete that line entirely or replace it with an assigned variable (e.g., mem_mb)
or a logging call if the computed memory in MB is intended to be used—refer to
the surrounding identifiers num_v_tiles, v_dim, k_dim, batch_size, and TILE_V to
locate the exact spot.

317-585: gdn_decode_kernel_big_batch_pretranspose and its launcher are never invoked.

At line 873, the entry point unconditionally selects run_gdn_decode_kernel_small_batch_pretranspose:

# Always use 8-CTA architecture (benchmarks show it's better for all batch sizes)
run_func = run_gdn_decode_kernel_small_batch_pretranspose

This means the big-batch kernel implementation (lines 317–585) and launcher (lines 691–791) are dead code. If benchmarks confirmed small-batch is always better, consider removing the unused code or adding a TODO/comment explaining retention for future use.

Also applies to: 691-791, 872-873

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py` around lines 317 - 585,
The big-batch kernel gdn_decode_kernel_big_batch_pretranspose and its launcher
are never used because run_func is unconditionally set to
run_gdn_decode_kernel_small_batch_pretranspose; either remove the unused
big-kernel and its launcher (clean up dead code) or restore a selection path:
update the launcher logic that sets run_func to choose between
run_gdn_decode_kernel_big_batch_pretranspose and
run_gdn_decode_kernel_small_batch_pretranspose based on batch-size (B) or a
config flag, or if you keep the big kernel for future use, add a clear TODO
comment above gdn_decode_kernel_big_batch_pretranspose and its launcher
explaining why it is retained and under what condition it should be re-enabled.
flashinfer/gdn_kernels/gdn_decode_nontranspose.py (2)

541-544: Dead code: expression result discarded, unused variable.

Line 543 accesses h0_indices.layout.shape[0] but discards the result. Additionally, k_dim is unpacked but never used.

-    batch_hv_dim, k_dim, v_dim = h0_source.layout.shape
-    h0_indices.layout.shape[0]
+    batch_hv_dim, _k_dim, v_dim = h0_source.layout.shape
     batch_size = batch_hv_dim  # batch_hv_dim = B * HV
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_nontranspose.py` around lines 541 - 544,
Remove the dead expression and unused unpacking: delete the stray statement that
just reads h0_indices.layout.shape[0] (it discards the result) and either use
k_dim/v_dim where needed or stop unpacking unused k_dim (e.g., unpack only the
used dimensions or rename unused to _k_dim) so that batch_hv_dim, v_dim and
batch_size remain correct; update the block around h0_source.layout.shape and
the variables batch_hv_dim, k_dim, v_dim and h0_indices to reflect the change.

622-625: Same issue: dead expression and unused k_dim.

-    batch_hv_dim, k_dim, v_dim = h0_source.layout.shape
-    h0_indices.layout.shape[0]
+    batch_hv_dim, _k_dim, v_dim = h0_source.layout.shape
     batch_size = batch_hv_dim  # batch_hv_dim = B * HV
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_nontranspose.py` around lines 622 - 625,
Remove the dead standalone expression "h0_indices.layout.shape[0]" and either
use or mark the unused k_dim from the unpack of h0_source.layout.shape;
specifically, replace the unused variable name k_dim with _k_dim (or otherwise
consume it) and delete the stray h0_indices.layout.shape[0] line so only
meaningful assignments remain (batch_hv_dim, _k_dim, v_dim =
h0_source.layout.shape and batch_size = batch_hv_dim).
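
As a small illustration of the cleanup suggested for both locations, the sketch below uses `namedtuple` stand-ins for objects that expose `.layout.shape`. The stand-ins and the helper name `get_batch_and_v_dim` are assumptions for illustration, not the real tensor types or code.

```python
from collections import namedtuple

# Minimal stand-ins for objects exposing `.layout.shape` (illustration only).
Layout = namedtuple("Layout", ["shape"])
Tensor = namedtuple("Tensor", ["layout"])


def get_batch_and_v_dim(h0_source):
    # `_k_dim` is unpacked only to match the 3-tuple; the underscore marks it unused.
    batch_hv_dim, _k_dim, v_dim = h0_source.layout.shape
    batch_size = batch_hv_dim  # batch_hv_dim = B * HV
    return batch_size, v_dim
```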
flashinfer/gdn_decode.py (1)

241-260: Kernel invocations assume backends are available without explicit checks.

All three APIs (gated_delta_rule_decode_pretranspose, gated_delta_rule_decode, gated_delta_rule_mtp) call their respective run_* functions without verifying availability. While the MTP case is the most problematic (helper functions called first), the other two will also produce confusing errors if backends aren't installed.

Consider adding availability guards at the start of each function, similar to the MTP fix suggested above.

Also applies to: 382-401, 550-577

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 241 - 260, Add explicit availability
checks at the start of each public API (`gated_delta_rule_decode_pretranspose`,
`gated_delta_rule_decode`, and `gated_delta_rule_mtp`) before calling their
respective helpers (`run_pretranspose_decode`, `run_decode`, `run_mtp_decode`):
detect whether the native backend or helper is present (e.g., via the same
availability predicate used in the MTP fix or by checking the imported helper),
and if missing raise a clear ImportError/RuntimeError with guidance on
installing the backend instead of proceeding to call the run_* function; apply
the same guard pattern to the other affected ranges (the other occurrences
around lines 382-401 and 550-577) so no run_* helper is invoked when its backend
is unavailable.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/gdn_decode.py`:
- Around line 486-493: The code calls get_tile_v_mtp and get_vec_size_mtp
without checking _MTP_AVAILABLE, which will raise a TypeError if those helpers
are None; add an early availability check before calling them (e.g., at the
start of the validation block around B, T, H, K) that raises a clear
RuntimeError or ValueError when _MTP_AVAILABLE is False, mentioning
_MTP_AVAILABLE and the two helper names, and only call get_tile_v_mtp and
get_vec_size_mtp when _MTP_AVAILABLE is True so tile_v and vec_size are computed
safely.

---

Nitpick comments:
In `@flashinfer/gdn_decode.py`:
- Around line 241-260: Add explicit availability checks at the start of each
public API (`gated_delta_rule_decode_pretranspose`, `gated_delta_rule_decode`,
and `gated_delta_rule_mtp`) before calling their respective helpers
(`run_pretranspose_decode`, `run_decode`, `run_mtp_decode`): detect whether the
native backend or helper is present (e.g., via the same availability predicate
used in the MTP fix or by checking the imported helper), and if missing raise a
clear ImportError/RuntimeError with guidance on installing the backend instead
of proceeding to call the run_* function; apply the same guard pattern to the
other affected ranges (the other occurrences around lines 382-401 and 550-577)
so no run_* helper is invoked when its backend is unavailable.

In `@flashinfer/gdn_kernels/gdn_decode_nontranspose.py`:
- Around line 541-544: Remove the dead expression and unused unpacking: delete
the stray statement that just reads h0_indices.layout.shape[0] (it discards the
result) and either use k_dim/v_dim where needed or stop unpacking unused k_dim
(e.g., unpack only the used dimensions or rename unused to _k_dim) so that
batch_hv_dim, v_dim and batch_size remain correct; update the block around
h0_source.layout.shape and the variables batch_hv_dim, k_dim, v_dim and
h0_indices to reflect the change.
- Around line 622-625: Remove the dead standalone expression
"h0_indices.layout.shape[0]" and either use or mark the unused k_dim from the
unpack of h0_source.layout.shape; specifically, replace the unused variable name
k_dim with _k_dim (or otherwise consume it) and delete the stray
h0_indices.layout.shape[0] line so only meaningful assignments remain
(batch_hv_dim, _k_dim, v_dim = h0_source.layout.shape and batch_size =
batch_hv_dim).

In `@flashinfer/gdn_kernels/gdn_decode_pretranspose.py`:
- Around line 799-813: The function _get_compiled_decode_kernel currently
returns an empty dict and uses its parameters only to form the functools.cache
key, which is confusing; update the function docstring to explicitly state that
B, T, H, HV, K, V, dtype, scale, and use_qk_l2norm are only used as cache keys
and that the function intentionally returns an empty placeholder (to be
populated later on first call) so readers understand the pattern and won't
expect parameter usage in the body.
- Around line 741-742: In the function that computes num_v_tiles (same pattern
as run_gdn_decode_kernel_small_batch_pretranspose), remove the dead standalone
expression "v_dim * k_dim * batch_size * 4 / 1024 / 1024" or convert it to a
meaningful use (e.g., assign to a clearly named variable like estimated_mb or
log it) so the computed value is not discarded; update the code around
num_v_tiles to either delete the stray expression or replace it with an
assignment or logging call to preserve the intent.
- Around line 638-639: Remove the dead expression that computes v_dim * k_dim *
batch_size * 4 / 1024 / 1024 (likely a leftover debug calculation) in
gdn_decode_pretranspose.py; either delete that line entirely or replace it with
an assigned variable (e.g., mem_mb) or a logging call if the computed memory in
MB is intended to be used—refer to the surrounding identifiers num_v_tiles,
v_dim, k_dim, batch_size, and TILE_V to locate the exact spot.
- Around line 317-585: The big-batch kernel
gdn_decode_kernel_big_batch_pretranspose and its launcher are never used because
run_func is unconditionally set to
run_gdn_decode_kernel_small_batch_pretranspose; either remove the unused
big-kernel and its launcher (clean up dead code) or restore a selection path:
update the launcher logic that sets run_func to choose between
run_gdn_decode_kernel_big_batch_pretranspose and
run_gdn_decode_kernel_small_batch_pretranspose based on batch-size (B) or a
config flag, or if you keep the big kernel for future use, add a clear TODO
comment above gdn_decode_kernel_big_batch_pretranspose and its launcher
explaining why it is retained and under what condition it should be re-enabled.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ae67bcd5-e9c4-4f73-bb19-38be9423c7e6

📥 Commits

Reviewing files that changed from the base of the PR and between 62ac6cb and bc4f98b.

📒 Files selected for processing (5)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/gdn_decode_mtp.py
  • flashinfer/gdn_kernels/gdn_decode_nontranspose.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py

Collaborator

@kahyunnam kahyunnam left a comment

lgtm

ameynaik-hub and others added 4 commits March 4, 2026 17:31
Improve GDN MTP (Multi-Token Processing) kernel performance by:

- Add instruction-level parallelism (ILP) with 1/2/4/8-row processing
- Enable shared memory caching for V tensor tiles
- Dynamic dispatch based on batch size and sequence length
- Change default: disable_state_update=False (h always updated)
- Keep original kernel for BS < 8 to avoid regression

This commit adds the optimized kernel while preserving the original kernel
for small batch sizes. The dispatch logic selects:
- BS < 8: Original kernel (no ILP overhead, matches baseline performance)
- BS >= 8: Optimized kernel with ILP rows and SMEM v caching

Performance (Qwen3-Next: q=k=16, v=32, d=128, FP32 state, cache ON,
state update ON, 1000 iterations):

Previous Kernel Times (us):

| BS  |  T=2  |  T=3  |  T=4  |  T=5  |  T=6  |  T=7  |  T=8  |
|-----|-------|-------|-------|-------|-------|-------|-------|
|   1 |  5.70 |  7.10 |  8.42 |  9.92 | 11.07 | 12.67 | 14.05 |
|   2 |  6.82 |  8.26 | 10.11 | 11.90 | 13.39 | 15.17 | 16.90 |
|   4 | 10.21 | 12.70 | 14.56 | 22.50 | 19.10 | 29.34 | 32.86 |
|   8 | 16.26 | 20.41 | 30.27 | 35.33 | 40.83 | 45.79 | 51.94 |
|  16 | 28.70 | 42.88 | 52.21 | 61.31 | 70.75 | 82.91 | 90.05 |
|  32 | 61.79 | 82.08 |104.83 |122.94 |148.58 |162.91 |184.38 |
|  64 |105.41 |137.47 |167.49 |201.09 |233.28 |266.94 |301.44 |
| 128 |192.58 |241.41 |296.42 |352.03 |409.98 |471.23 |532.85 |
| 256 |372.38 |472.18 |582.35 |703.45 |829.63 |958.45 |1094.30|
| 512 |744.24 |963.50 |1188.12|1439.23|1689.53|1954.30|2231.74|

New Kernel Times (us):

| BS  |  T=2  |  T=3  |  T=4  |  T=5  |  T=6  |  T=7  |  T=8  |
|-----|-------|-------|-------|-------|-------|-------|-------|
|   1 |  5.70 |  7.14 |  8.42 |  9.92 | 11.07 | 12.70 | 14.08 |
|   2 |  6.82 |  8.26 | 10.14 | 11.87 | 13.41 | 15.20 | 16.90 |
|   4 | 10.21 | 12.70 | 14.56 | 22.46 | 19.07 | 29.34 | 32.86 |
|   8 | 14.85 | 19.23 | 30.24 | 26.43 | 31.20 | 34.24 | 38.18 |
|  16 | 25.21 | 32.10 | 52.16 | 61.41 | 70.82 | 83.01 | 90.16 |
|  32 | 50.14 | 63.36 | 73.47 | 85.73 | 99.52 |113.76 |127.76 |
|  64 | 88.32 |116.48 |141.36 |165.18 |190.11 |214.72 |238.29 |
| 128 |169.54 |217.57 |264.86 |312.86 |360.96 |412.00 |458.97 |
| 256 |331.52 |425.28 |519.10 |617.12 |715.73 |822.96 |926.97 |
| 512 |658.56 |846.64 |1036.64|1235.76|1443.53|1660.64|1872.56|

Speedup vs Previous FlashInfer Kernel:

| BS  | T=2  | T=3  | T=4  | T=5  | T=6  | T=7  | T=8  |
|-----|------|------|------|------|------|------|------|
|   1 | 1.00 | 0.99 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|   2 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|   4 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|   8 | 1.09 | 1.06 | 1.00 | 1.34 | 1.31 | 1.34 | 1.36 |
|  16 | 1.14 | 1.34 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|  32 | 1.23 | 1.30 | 1.43 | 1.43 | 1.49 | 1.43 | 1.44 |
|  64 | 1.19 | 1.18 | 1.18 | 1.22 | 1.23 | 1.24 | 1.27 |
| 128 | 1.14 | 1.11 | 1.12 | 1.13 | 1.14 | 1.14 | 1.16 |
| 256 | 1.12 | 1.11 | 1.12 | 1.14 | 1.16 | 1.16 | 1.18 |
| 512 | 1.13 | 1.14 | 1.15 | 1.16 | 1.17 | 1.18 | 1.19 |

Small batch sizes (BS < 8): 1.00x (no regression).
Large batch sizes (BS >= 8): 1.00x-1.49x improvement, avg ~1.20x.
All 70 correctness tests pass (10 BS x 7 T values).
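
For reference, each speedup entry is the previous time divided by the new time, rounded to two decimals. The sketch below recomputes the BS=32 row from the two tables above; the helper name `speedups` is illustrative only.

```python
# BS=32 row, T=2..8, copied from the tables above (times in microseconds).
previous_us = [61.79, 82.08, 104.83, 122.94, 148.58, 162.91, 184.38]
new_us = [50.14, 63.36, 73.47, 85.73, 99.52, 113.76, 127.76]


def speedups(prev, new):
    """Speedup = previous time / new time, rounded to two decimals."""
    return [round(p / n, 2) for p, n in zip(prev, new)]


row = speedups(previous_us, new_us)
print(row)  # [1.23, 1.3, 1.43, 1.43, 1.49, 1.43, 1.44]
```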

Also adds --update-state flag to bench_gdn_decode.py to test with
disable_state_update=False (h output updated after each chunk).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Replace 8 separate 1D register tensors (r_h, r_h2, ..., r_h8) with a
single 2D register tensor r_h[8, vec_size]. This is a code quality
improvement with no performance impact - the generated PTX is identical
since all indices are constexpr and loops are fully unrolled.

Changes:
- Use cute.make_layout((8, vec_size), stride=(vec_size, 1)) for r_h
- Access via r_h[row, i] instead of r_h{N}[i]
- Use cute.slice_(r_h, (row, None)) for copy operations
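
The layout change can be pictured in plain Python (this is not CuTe DSL code): with shape `(8, vec_size)` and stride `(vec_size, 1)`, element `[row, i]` lives at linear offset `row * vec_size + i`, so one flat buffer holds the same data as eight separate 1D buffers. `VEC_SIZE = 4` and the helper names are illustrative assumptions.

```python
VEC_SIZE = 4  # illustrative value; the kernel's vec_size may differ
NUM_ROWS = 8


def offset(row: int, i: int, vec_size: int = VEC_SIZE) -> int:
    """Linear offset for shape (NUM_ROWS, vec_size) with stride (vec_size, 1)."""
    return row * vec_size + i


# Old style: eight separate 1D buffers, one per row.
separate = [[float(10 * r + i) for i in range(VEC_SIZE)] for r in range(NUM_ROWS)]

# New style: a single flat buffer addressed as r_h[row, i].
flat = [0.0] * (NUM_ROWS * VEC_SIZE)
for r in range(NUM_ROWS):
    for i in range(VEC_SIZE):
        flat[offset(r, i)] = separate[r][i]


def row_slice(buf, row, vec_size=VEC_SIZE):
    """Plain-Python analogue of taking one row, i.e. the (row, None) slice."""
    start = offset(row, 0, vec_size)
    return buf[start:start + vec_size]
```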

Verified: All 70 MTP benchmark configurations match previous results
within measurement noise (<0.5% variance).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…decode.py

Split the monolithic gdn_decode.py (4435 lines) into focused modules:
- gdn_kernels/gdn_decode_pretranspose.py: V-major state T=1 kernel (912 lines)
- gdn_kernels/gdn_decode_nontranspose.py: K-major state T=1 kernel (808 lines)
- gdn_kernels/gdn_decode_mtp.py: MTP T>1 kernels with dispatch (2445 lines)
- gdn_decode.py: thin API layer with validation and dtype handling (589 lines)

All 164 tests pass. Benchmark numbers unchanged.

AI-assisted (Claude Code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add initial_state/initial_state_indices parameters to
gated_delta_rule_decode_pretranspose, enabling direct pool-based
state gathering via the bf16 fast path (no caller-side gather/scatter).
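
The difference between caller-side gather/scatter and index-based access into a state pool can be sketched in plain Python. This is a conceptual illustration only: `step`, both function names, and the assumption that updated states are written back into the pool are not taken from the kernel code.

```python
def step(state):
    """Stand-in for one decode update of a recurrent state."""
    return [x + 1.0 for x in state]


def decode_with_caller_gather(pool, indices):
    gathered = [list(pool[i]) for i in indices]  # extra gather copy
    updated = [step(s) for s in gathered]
    for i, s in zip(indices, updated):  # extra scatter copy
        pool[i] = s


def decode_with_pool_indices(pool, indices):
    # The "kernel" addresses pool slots directly through the index list.
    for i in indices:
        pool[i] = step(pool[i])
```

Both variants leave the pool in the same state; the second avoids materializing a gathered batch on the caller side.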

Ported from eee401b which landed on main after our branch diverged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ameynaik-hub ameynaik-hub force-pushed the ameyn/improve_gdn_mtp_fp32state branch from bc4f98b to 71e774b Compare March 5, 2026 01:44
@ameynaik-hub
Contributor Author

@kaixih can you please verify the merge?

@kahyunnam
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !349 has been updated with latest changes, and the CI pipeline #45381965 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45381965: 9/20 passed

@aleozlx
Collaborator

aleozlx commented Mar 7, 2026

tests seem clean

approval unblocked!

we can merge once public CI clears it

@aleozlx aleozlx added the ready label Mar 7, 2026
@aleozlx aleozlx merged commit 65d6e4a into flashinfer-ai:main Mar 7, 2026
48 of 67 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
…shinfer-ai#2618)

Improve GDN MTP (Multi-Token Processing) kernel performance by:

- Add instruction-level parallelism (ILP) with 1/2/4/8-row processing
- Enable shared memory caching for V tensor tiles
- Dynamic dispatch based on batch size and sequence length
- Change default: disable_state_update=False (h always updated)
- Keep original kernel for BS < 8 to avoid regression

This commit adds the optimized kernel while preserving the original kernel
for small batch sizes. The dispatch logic selects:
- BS < 8: Original kernel (no ILP overhead, matches baseline performance)
- BS >= 8: Optimized kernel with ILP rows and SMEM v caching

# Performance (Qwen3-Next: q=k=16, v=32, d=128, FP32 state, **cache ON, state update ON**, 1000 iterations):

Previous Kernel Times (us):

| BS  |  T=2  |  T=3  |  T=4  |  T=5  |  T=6  |  T=7  |  T=8  |
|-----|-------|-------|-------|-------|-------|-------|-------|
|   1 |  5.70 |  7.10 |  8.42 |  9.92 | 11.07 | 12.67 | 14.05 |
|   2 |  6.82 |  8.26 | 10.11 | 11.90 | 13.39 | 15.17 | 16.90 |
|   4 | 10.21 | 12.70 | 14.56 | 22.50 | 19.10 | 29.34 | 32.86 |
|   8 | 16.26 | 20.41 | 30.27 | 35.33 | 40.83 | 45.79 | 51.94 |
|  16 | 28.70 | 42.88 | 52.21 | 61.31 | 70.75 | 82.91 | 90.05 |
|  32 | 61.79 | 82.08 |104.83 |122.94 |148.58 |162.91 |184.38 |
|  64 |105.41 |137.47 |167.49 |201.09 |233.28 |266.94 |301.44 |
| 128 |192.58 |241.41 |296.42 |352.03 |409.98 |471.23 |532.85 |
| 256 |372.38 |472.18 |582.35 |703.45 |829.63 |958.45 |1094.30|
| 512 |744.24 |963.50 |1188.12|1439.23|1689.53|1954.30|2231.74|

New Kernel Times (us):

| BS  |  T=2  |  T=3  |  T=4  |  T=5  |  T=6  |  T=7  |  T=8  |
|-----|-------|-------|-------|-------|-------|-------|-------|
|   1 |  5.70 |  7.14 |  8.42 |  9.92 | 11.07 | 12.70 | 14.08 |
|   2 |  6.82 |  8.26 | 10.14 | 11.87 | 13.41 | 15.20 | 16.90 |
|   4 | 10.21 | 12.70 | 14.56 | 22.46 | 19.07 | 29.34 | 32.86 |
|   8 | 14.85 | 19.23 | 30.24 | 26.43 | 31.20 | 34.24 | 38.18 |
|  16 | 25.21 | 32.10 | 52.16 | 61.41 | 70.82 | 83.01 | 90.16 |
|  32 | 50.14 | 63.36 | 73.47 | 85.73 | 99.52 |113.76 |127.76 |
|  64 | 88.32 |116.48 |141.36 |165.18 |190.11 |214.72 |238.29 |
| 128 |169.54 |217.57 |264.86 |312.86 |360.96 |412.00 |458.97 |
| 256 |331.52 |425.28 |519.10 |617.12 |715.73 |822.96 |926.97 |
| 512 |658.56 |846.64 |1036.64|1235.76|1443.53|1660.64|1872.56|

Speedup vs Previous FlashInfer Kernel:

| BS  | T=2  | T=3  | T=4  | T=5  | T=6  | T=7  | T=8  |
|-----|------|------|------|------|------|------|------|
|   1 | 1.00 | 0.99 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|   2 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|   4 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|   8 | 1.09 | 1.06 | 1.00 | 1.34 | 1.31 | 1.34 | 1.36 |
|  16 | 1.14 | 1.34 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
|  32 | 1.23 | 1.30 | 1.43 | 1.43 | 1.49 | 1.43 | 1.44 |
|  64 | 1.19 | 1.18 | 1.18 | 1.22 | 1.23 | 1.24 | 1.27 |
| 128 | 1.14 | 1.11 | 1.12 | 1.13 | 1.14 | 1.14 | 1.16 |
| 256 | 1.12 | 1.11 | 1.12 | 1.14 | 1.16 | 1.16 | 1.18 |
| 512 | 1.13 | 1.14 | 1.15 | 1.16 | 1.17 | 1.18 | 1.19 |

Small batch sizes (BS < 8): 1.00x (no regression).
Large batch sizes (BS >= 8): 1.00x-1.49x improvement, avg ~1.20x.
All 70 correctness tests pass (10 BS x 7 T values).

Also adds --update-state flag to bench_gdn_decode.py to test with
disable_state_update=False (h output updated after each chunk).

# Performance with the MTP cache enabled but the h update disabled, so that the initial state is not overwritten (cache ON, state update OFF)

 ## Main Branch Kernel Times (µs)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|--------|--------|--------|--------|--------|--------|--------|--------|
| 1 | 5.70 | 7.04 | 8.29 | 9.76 | 10.94 | 12.53 | 13.98 |
| 2 | 6.56 | 8.13 | 9.86 | 11.62 | 13.15 | 14.91 | 16.61 |
| 4 | 9.57 | 11.78 | 14.05 | 16.18 | 18.53 | 28.72 | 23.65 |
| 8 | 14.24 | 17.79 | 28.16 | 33.23 | 38.90 | 44.06 | 49.71 |
| 16 | 23.07 | 38.29 | 46.70 | 55.87 | 64.50 | 74.48 | 84.77 |
| 32 | 41.60 | 67.65 | 84.58 | 101.52 | 121.17 | 136.34 | 154.26 |
| 64 | 77.60 | 113.41 | 142.64 | 175.58 | 205.82 | 236.59 | 274.37 |
| 128 | 148.74 | 206.43 | 260.35 | 317.28 | 371.31 | 430.32 | 490.70 |
| 256 | 281.86 | 403.75 | 503.42 | 617.65 | 738.83 | 862.19 | 998.87 |
| 512 | 549.49 | 809.55 | 1018.71 | 1276.85 | 1525.92 | 1794.26 | 2057.52 |

  ## Optimized Kernel Times (µs)

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|--------|--------|--------|--------|--------|--------|--------|--------|
| 1 | 5.66 | 7.04 | 8.27 | 9.76 | 10.94 | 12.54 | 13.92 |
| 2 | 6.56 | 8.19 | 9.81 | 11.62 | 13.18 | 14.94 | 16.61 |
| 4 | 9.57 | 11.84 | 13.98 | 16.18 | 18.62 | 28.69 | 23.62 |
| 8 | 14.08 | 17.60 | 28.16 | 24.96 | 29.34 | 32.99 | 37.74 |
| 16 | 22.94 | 27.97 | 46.82 | 44.00 | 64.48 | 74.10 | 85.23 |
| 32 | 40.34 | 55.78 | 64.03 | 76.54 | 89.34 | 105.38 | 119.97 |
| 64 | 68.48 | 92.61 | 119.52 | 141.63 | 166.59 | 194.02 | 218.59 |
| 128 | 129.47 | 176.83 | 223.17 | 269.20 | 317.36 | 369.70 | 418.94 |
| 256 | 250.50 | 341.54 | 432.88 | 524.29 | 620.11 | 723.78 | 823.74 |
| 512 | 492.62 | 671.73 | 854.71 | 1041.25 | 1245.30 | 1458.29 | 1693.33 |

  ## Speedup (main_time / optimized_time)

  > Values > 1.0 = optimized kernel is faster

| BS \ T | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Avg |
|--------|-------|-------|-------|-------|-------|-------|-------|-------|
| 1 | 1.01 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| 2 | 1.00 | 0.99 | 1.01 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| 4 | 1.00 | 0.99 | 1.01 | 1.00 | 1.00 | 1.00 | 1.00 | 1.00 |
| 8 | 1.01 | 1.01 | 1.00 | **1.33** | **1.33** | **1.34** | **1.32** | **1.19** |
| 16 | 1.01 | **1.37** | 1.00 | **1.27** | 1.00 | 1.01 | 0.99 | 1.09 |
| 32 | 1.03 | **1.21** | **1.32** | **1.33** | **1.36** | **1.29** | **1.29** | **1.26** |
| 64 | 1.13 | **1.22** | **1.19** | **1.24** | **1.24** | **1.22** | **1.26** | **1.21** |
| 128 | 1.15 | **1.17** | **1.17** | **1.18** | **1.17** | **1.16** | **1.17** | **1.17** |
| 256 | 1.13 | **1.18** | **1.16** | **1.18** | **1.19** | **1.19** | **1.21** | **1.18** |
| 512 | 1.12 | **1.21** | **1.19** | **1.23** | **1.23** | **1.23** | **1.22** | **1.20** |

<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* New public GDN decode APIs with runtime backend selection, flexible
state layout/dtype handling, and an MTP option controlling state updates
(default behavior changed).
* **Tests**
* Added tests exercising MTP path with FP32 state, cache enabled, and
state-update enabled.
* **Benchmarks**
* CLI and benchmarks gain an --update-state flag and report/update-state
behavior during MTP runs.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…shinfer-ai#2618)

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
aleozlx added a commit that referenced this pull request Mar 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

fix api breaking changes for 0.6.7 release

## 🔍 Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

**API changes review**

API changes since v0.6.6

  PR #2520 + commit e35c19e (fixed to be compatible)

  Function: xqa()
Change: Added k_sf_cache=None, v_sf_cache=None as keyword-only params
(after *). Backward-compatible.
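
To see why appending keyword-only parameters after `*` is backward-compatible, consider the sketch below. `xqa_sketch` is a stand-in with a made-up positional signature, not the real `xqa()`; only the two new parameter names come from the note above.

```python
def xqa_sketch(q, k_cache, v_cache, *, k_sf_cache=None, v_sf_cache=None):
    """Stand-in function: new parameters sit after `*`, so they are keyword-only."""
    return {
        "q": q,
        "k_cache": k_cache,
        "v_cache": v_cache,
        "k_sf_cache": k_sf_cache,
        "v_sf_cache": v_sf_cache,
    }


# Existing positional call sites keep working and pick up the None defaults.
old_style = xqa_sketch("q", "k", "v")

# New call sites opt in by name.
new_style = xqa_sketch("q", "k", "v", k_sf_cache="k_sf", v_sf_cache="v_sf")
```

Passing the new arguments positionally raises `TypeError`, which keeps later signature changes from silently rebinding existing positional arguments.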

  PR #2618 (has PR #2730 to fix it)

  Function: gated_delta_rule_mtp()
Change: disable_state_update: bool = True → Optional[bool] = None. Still
defaults to True at runtime but emits a deprecation
  warning; will flip to False in 0.7.0.
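
The `Optional[bool] = None` pattern described here can be sketched as below. The helper name `resolve_disable_state_update` and the warning text are assumptions for illustration; the real function may word or structure this differently.

```python
import warnings
from typing import Optional


def resolve_disable_state_update(disable_state_update: Optional[bool] = None) -> bool:
    """Map the tri-state argument to a concrete bool, warning on the implicit default."""
    if disable_state_update is None:
        warnings.warn(
            "disable_state_update currently defaults to True; the default is "
            "planned to change to False in 0.7.0. Pass it explicitly to "
            "silence this warning.",
            DeprecationWarning,
            stacklevel=2,
        )
        return True
    return disable_state_update
```

Using `None` as the sentinel lets the library distinguish "caller did not choose" from an explicit `True`, so only callers relying on the default see the warning.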

  PR #2775 (expected — cute DSL MoE cleanup)

  Function: blockscaled_contiguous_grouped_gemm_nvfp4()
  Change: Entire @flashinfer_api decorated function deleted.

  Function: blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()
  Change: Entire @flashinfer_api decorated function deleted.

Function:
blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()
Change: @flashinfer_api decorator removed; added enable_pdl: bool = True
param.

  Function: blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()
  Change: @flashinfer_api decorator removed; added enable_pdl: bool = True
  param.

  Function: CuteDslMoEWrapper.__init__()
  Change: Added enable_pdl: bool = True param. Backward-compatible.

  Function: cute_dsl_fused_moe_nvfp4()
  Change: Added enable_pdl: bool = True param. Backward-compatible.
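
One way to sanity-check the "Backward-compatible" label on the `enable_pdl` additions is to compare signatures mechanically. This is a rough sketch, not the tooling used for this review; `added_params_are_optional` and the `fused_moe_*` functions are hypothetical, heavily simplified stand-ins.

```python
import inspect


def added_params_are_optional(old_fn, new_fn) -> bool:
    """Return True if every parameter new_fn adds over old_fn has a default."""
    old_params = inspect.signature(old_fn).parameters
    new_params = inspect.signature(new_fn).parameters
    added = [p for name, p in new_params.items() if name not in old_params]
    return all(p.default is not inspect.Parameter.empty for p in added)


# Hypothetical stand-ins for a 0.6.6 signature and two possible successors.
def fused_moe_old(x, moe_output=None, aux_stream=None): ...
def fused_moe_new(x, moe_output=None, aux_stream=None, enable_pdl: bool = True): ...
def fused_moe_breaking(x, enable_pdl, moe_output=None, aux_stream=None): ...


compatible = added_params_are_optional(fused_moe_old, fused_moe_new)
breaking = added_params_are_optional(fused_moe_old, fused_moe_breaking)
```

The check is deliberately narrow: it does not detect removed or reordered parameters, changed defaults, or changed return types, so those still need a manual pass.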

  PR #2428

  Function: rmsnorm_quant()
  Change: scale: float → scale: Union[float, torch.Tensor]; return type
  torch.Tensor → None.

  Function: fused_add_rmsnorm_quant()
  Change: scale: float → scale: Union[float, torch.Tensor].
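
A return type that changes from a tensor to `None` can break callers that use the returned value. The toy sketch below shows the failure mode and the fix, assuming the result is written into a caller-provided `out` buffer. `quant_old` and `quant_new` are hypothetical stand-ins that operate on Python lists; they do not implement RMSNorm or the real argument list.

```python
from typing import List


def quant_old(out: List[int], x: List[float], scale: float) -> List[int]:
    # Toy op: writes into `out` and also returns it (old style).
    out[:] = [round(v * scale) for v in x]
    return out


def quant_new(out: List[int], x: List[float], scale: float) -> None:
    # Same toy op, but the result is only available through `out`.
    out[:] = [round(v * scale) for v in x]


x = [0.5, 1.5, 2.5]

# Old-style call site: relies on the return value.
y_old = quant_old([0] * len(x), x, 2.0)

# The same pattern against the new version silently yields None.
y_broken = quant_new([0] * len(x), x, 2.0)

# Migrated call site: keep a reference to the buffer and read from it.
out = [0] * len(x)
quant_new(out, x, 2.0)
```

Call sites written as `y = f(...)` are the ones to audit; call sites that already read from their own `out` buffer need no change.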

  Quantization functions (relocated, not removed)

All quantization APIs (fp4_quantize, block_scale_interleave,
e2m1_and_ufp8sf_scale_to_float, shuffle_matrix_a, shuffle_matrix_sf_a,
nvfp4_quantize, nvfp4_batched_quantize, scaled_fp4_grouped_quantize,
mxfp4_quantize, mxfp4_dequantize, mxfp4_dequantize_host,
mxfp8_quantize, mxfp8_dequantize_host) were moved from
flashinfer/fp4_quantization.py and flashinfer/fp8_quantization.py to
flashinfer/quantization/. Signatures, @flashinfer_api decorators, and
__init__.py exports are preserved. No breakage.
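
A small sketch of how the "exports are preserved" claim could be spot-checked. `missing_exports` is a hypothetical helper, demonstrated on the standard library so it runs without FlashInfer installed. For this PR one would pass the package name and the quantization function names listed above, on the assumption that they are exported at the package top level.

```python
import importlib
from typing import Iterable, List


def missing_exports(module_name: str, names: Iterable[str]) -> List[str]:
    """Return the names that cannot be resolved as attributes of the module."""
    module = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]


# Stand-in check against the standard library; the last name is deliberately bogus.
stdlib_missing = missing_exports("json", ["dumps", "loads", "no_such_function"])
```

An empty list means every name still resolves from the old import location. It says nothing about signatures, which the diff below covers.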

```diff
$ git diff v0.6.6 | grep -A20 "@flashinfer_api"
     @flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
         sinks: Optional[torch.Tensor] = None,
         q_len_per_req: Optional[int] = 1,
         skip_softmax_threshold_scale_factor: Optional[float] = None,
+        kv_block_scales: Optional[
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        ] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.

@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper:
             enable_pdl = device_support_pdl(q.device)
         k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)

+        # Unpack kv_block_scales
+        key_block_scales = None
+        value_block_scales = None
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                key_block_scales, value_block_scales = kv_block_scales
--
-@flashinfer_api
-def fp4_quantize(
-    input: torch.Tensor,
-    global_scale: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    sf_use_ue8m0: bool = False,
-    is_sf_swizzled_layout: bool = True,
-    is_sf_8x4_layout: bool = False,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to FP4 format.
-
-    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
-@flashinfer_api
-def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
-    """Swizzle block scale tensor for FP4 format.
-
-    This function swizzles the block scale tensor to optimize memory access patterns
-    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
-
-    Args:
-        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
-
-    Returns:
-        torch.Tensor: Swizzled tensor with the same shape as input.
-
-    Raises:
-        AssertionError: If input dtype is not uint8 or bfloat16.
-    """
-    # TODO(shuw): check input dtype is uint8
-    assert (
-        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
-    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
-
--
-@flashinfer_api
-def e2m1_and_ufp8sf_scale_to_float(
-    e2m1_tensor: torch.Tensor,
-    ufp8_scale_tensor: torch.Tensor,
-    global_scale_tensor: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    ufp8_type: int = 1,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
-
-    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
-    back to float values using the associated UFP8 scale factors and global scale.
-
-    Args:
-        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
-        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
-        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
-@flashinfer_api
-def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
-    """
-    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
-    """
-    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
-
-    return input_tensor[row_indices.to(input_tensor.device)]
-
-
-@flashinfer_api
-def shuffle_matrix_sf_a(
-    input_tensor: torch.Tensor,
-    epilogue_tile_m: int,
-    num_elts_per_sf: int = 16,
-):
-    """
-    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
-    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
-    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
-    layout.
-    This function expects the input to be in linear layout. It's done this
-    way because the scaling factors in the NVFP4 checkpoints are quantized
-    and are in linear layout.
-    This function doesn't add padding.
-    """
-
-    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
-
-    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
-
--
-@flashinfer_api
-def nvfp4_quantize(
-    a,
-    a_global_sf,
-    sfLayout=SfLayout.layout_128x4,
-    do_shuffle=False,
-    sf_vec_size=16,
-    enable_pdl=None,
-):
-    """
-    Quantize input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
-        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-
--
-@flashinfer_api
-def mxfp4_quantize(a):
-    """
-    Quantize input tensor to MXFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-            - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-    """
-    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
-    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
-    return a_fp4, a_sf
-
-
-@flashinfer_api
-def mxfp4_dequantize(a_fp4, a_sf):
-    """
-    Dequantize input tensor from MXFP4 format.
-
-    Parameters:
-        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    return e2m1_and_ufp8sf_scale_to_float(
-        a_fp4.cpu().view(torch.uint8),
-        a_sf.cpu().view(torch.uint8).reshape(-1),
-        torch.tensor([1.0], device=a_fp4.device),
-        32,
-        0,
-        True,
-    )
-
--
-@flashinfer_api
-def mxfp4_dequantize_host(
-    weight: torch.Tensor,
-    scale: torch.Tensor,
-    group_size: int = 32,
-) -> torch.Tensor:
-    """
-    Dequantize input tensor from MXFP4 format on host.
-
-    Parameters:
-        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-        group_size (int, optional): Group size for dequantization. Defaults to 32.
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
-    major, minor = get_compute_capability(
-        torch.device("cuda:0")
-    )  # use any cuda device to get a compute capability
--
-@flashinfer_api
-def nvfp4_batched_quantize(
-    a,
-    a_global_sf,
-    sf_vec_size=16,
-):
-    """
-    Quantize batched input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
--
-@flashinfer_api
-def scaled_fp4_grouped_quantize(
-    a,
-    mask,
-    a_global_sf,
-):
-    """
-    quantize batched input tensor to NVFP4 format with mask.
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
-    a_fp4, a_sf = get_fp4_quantization_module(
-        device_arch
--
-@flashinfer_api
-def mxfp8_quantize(
-    input: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-    alignment: int = 32,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to MxFP8 format.
-
-    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-        alignment (int, optional): sfVecSize. Defaults to 32.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3
--
-@flashinfer_api
-def mxfp8_dequantize_host(
-    input: torch.Tensor,
-    scale_tensor: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Dequantize input tensor from MxFP8 format.
-
-    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
-    back to float values using the associated scale factors.
-
-    Args:
-        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
-        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
-
-    Returns:
-        torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
-
-    """
-
--
-@flashinfer_api
 def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     vectorized_f32: bool = True,
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads.

@@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     major, minor = get_compute_capability(a.device)
     if major != 10:
         raise ValueError(
-            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). "
+            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). "
             f"Got SM{major}{minor}."
         )

--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (128, 128),
-    cluster_shape_mn: Tuple[int, int] = (1, 1),
-    sm_count: Optional[int] = None,
-) -> torch.Tensor:
-    """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization.
-
--
-@flashinfer_api
 def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     cluster_shape_mn: Tuple[int, int] = (2, 1),
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads.

@@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
             expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
         token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
         out: Optional output tensor, shape (seq_len, n). Created if None.
-             This tensor is used for atomic accumulation, so it should be zero-initialized.
+             This tensor is used for atomic accumulation. If `out` is
+             provided, it must already be zero-initialized by the caller.
+             If `out` is None, this function allocates a zero-initialized
+             output tensor. Passing a non-zeroed `out` buffer will silently
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    out_scale: Optional[torch.Tensor] = None,
-    global_scale: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (256, 128),
-    cluster_shape_mn: Tuple[int, int] = (2, 1),
-    vectorized_f32: bool = True,
-    sm_count: Optional[int] = None,
--
     @flashinfer_api
     def __init__(
         self,
@@ -347,6 +355,7 @@ class CuteDslMoEWrapper:
         sf_vec_size: int = 16,
         output_dtype: torch.dtype = torch.bfloat16,
         device: str = "cuda",
+        enable_pdl: bool = True,
     ):
         """Initialize the MoE wrapper.

@@ -363,6 +372,7 @@ class CuteDslMoEWrapper:
             sf_vec_size: Scale factor vector size. Default: 16.
             output_dtype: Output data type. Default: torch.bfloat16.
             device: Device for buffer allocation. Default: "cuda".
+            enable_pdl: Enable Programmatic Dependent Launch. Default: True.
         """
         self.num_experts = num_experts
         self.top_k = top_k
@@ -376,6 +386,7 @@ class CuteDslMoEWrapper:
         self.sf_vec_size = sf_vec_size
--
     @flashinfer_api
@@ -550,9 +570,10 @@ class CuteDslMoEWrapper:
                 f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
             )

-        # Allocate output buffer if not using pre-allocated one
+        # Slice the pre-allocated buffer to the active batch so that
+        # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
         if self.use_cuda_graph:
-            moe_output = self._moe_output
+            moe_output = self._moe_output[:num_tokens]
         else:
             moe_output = torch.empty(
                 (num_tokens, self.hidden_size),
@@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Internal implementation called by auto-tuner for functional API."""
--
 @flashinfer_api
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
@@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Run fused MoE computation using CuteDSL NVFP4 kernels.

+    Supported architectures: SM100, SM103.
+
     This is the simple functional API. For CUDA graph support, use
     `CuteDslMoEWrapper` instead.

@@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4(
         local_expert_offset=local_expert_offset,
         use_fused_finalize=use_fused_finalize,
         output_dtype=output_dtype,
+        enable_pdl=enable_pdl,
--
 @flashinfer_api
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
@@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose(
         - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
           and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
           (supports both the direct ``state`` path and the pool+indices path).
-        - pool+indices (``initial_state``/``initial_state_indices``) only supported
-          via the bf16 fast path; float32 state raises an error.
+        - pool+indices (``initial_state``/``initial_state_indices``) supported on
+          both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path
+          (T=1). The float32 path also supports negative indices for padding.
         - Legacy path (float32 state, T=1): K and V must be multiples of 4.
     """
     # Validate input shapes
@@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose(
         return_state = initial_state if use_pool else state
         return output, return_state

-    # Legacy path: T=1 only, float32 state (no pool+indices support)
-    assert not use_pool, (
--
 @flashinfer_api
 def gated_delta_rule_mtp(
     q: torch.Tensor,
@@ -2427,7 +489,7 @@ def gated_delta_rule_mtp(
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     intermediate_states_buffer: Optional[torch.Tensor] = None,
-    disable_state_update: bool = True,
+    disable_state_update: Optional[bool] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -2463,8 +525,15 @@ def gated_delta_rule_mtp(
         intermediate_states_buffer (Optional[torch.Tensor]):
             Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``.
             If None, intermediate states are not cached.
-        disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``True``.
+        disable_state_update (Optional[bool]):
+            If True, the initial state is not updated. Currently defaults to ``True``.
+            Please pass this argument explicitly — the default will change to ``False``
--
 @flashinfer_api
@@ -60,16 +120,14 @@ def rmsnorm(
     output: torch.Tensor
         Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
     """
-    if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
     if out is None:
         out = torch.empty_like(input)
-    _rmsnorm(out, input, weight, eps, enable_pdl)
+    _rmsnorm_impl(out, input, weight, eps, enable_pdl)
     return out


 @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
-def _rmsnorm(
+def _rmsnorm_impl(
     out: torch.Tensor,
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -78,11 +136,21 @@ def _rmsnorm(
--
 @flashinfer_api
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
@@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek(
         If return_lse is False, the output will be a single tensor.
     """
     if not is_sm12x_supported(query.device):
-        major, minor = get_compute_capability(query.device)
-        if major == 12:
-            min_cuda = "13.0" if minor >= 1 else "12.8"
-            raise ValueError(
-                f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} "
-                f"for SM12{minor}x GPUs."
-            )
         raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.")
     assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
         "currently only support deepseek r1 192 query and 128 value"
     )
-    module = get_trtllm_fmha_v2_module()
+    module = get_trtllm_fmha_v2_sm120_module()
     is_e4m3 = query.dtype == torch.float8_e4m3fn
--
+@flashinfer_api
+def trtllm_fmha_v2_prefill(
+    qkv: Union[
+        torch.Tensor,
+        Tuple[torch.Tensor, torch.Tensor],
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ],
+    input_layout: str,
+    workspace_buffer: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_q_len: int,
+    max_kv_len: int,
+    bmm1_scale: float,
+    bmm2_scale: float,
+    batch_size: int,
+    cum_seq_lens_q: torch.Tensor,
+    cum_seq_lens_kv: torch.Tensor,
+    block_tables: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: Optional[Union[torch.dtype, str]] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
--
+@flashinfer_api
+def fp4_quantize(
+    input: torch.Tensor,
+    global_scale: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    sf_use_ue8m0: bool = False,
+    is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to FP4 format.
+
+    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
+@flashinfer_api
+def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
+    """Swizzle block scale tensor for FP4 format.
+
+    This function swizzles the block scale tensor to optimize memory access patterns
+    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
+
+    Args:
+        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
+
+    Returns:
+        torch.Tensor: Swizzled tensor with the same shape as input.
+
+    Raises:
+        AssertionError: If input dtype is not uint8 or bfloat16.
+    """
+    # TODO(shuw): check input dtype is uint8
+    assert (
+        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
+    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
+
--
+@flashinfer_api
+def e2m1_and_ufp8sf_scale_to_float(
+    e2m1_tensor: torch.Tensor,
+    ufp8_scale_tensor: torch.Tensor,
+    global_scale_tensor: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    ufp8_type: int = 1,
+    is_sf_swizzled_layout: bool = True,
+) -> torch.Tensor:
+    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
+
+    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
+    back to float values using the associated UFP8 scale factors and global scale.
+
+    Args:
+        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
+        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
+        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
+        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
+@flashinfer_api
+def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
+    """
+    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
+    """
+    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
+
+    return input_tensor[row_indices.to(input_tensor.device)]
+
+
+@flashinfer_api
+def shuffle_matrix_sf_a(
+    input_tensor: torch.Tensor,
+    epilogue_tile_m: int,
+    num_elts_per_sf: int = 16,
+):
+    """
+    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
+    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
+    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
+    layout.
+    This function expects the input to be in linear layout. It's done this
+    way because the scaling factors in the NVFP4 checkpoints are quantized
+    and are in linear layout.
+    This function doesn't add padding.
+    """
+
+    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
+
+    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
+
--
+@flashinfer_api
+def nvfp4_quantize(
+    a,
+    a_global_sf,
+    sfLayout=SfLayout.layout_128x4,
+    do_shuffle=False,
+    sf_vec_size=16,
+    enable_pdl=None,
+):
+    """
+    Quantize input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
+        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability. Defaults to None.
+
--
+@flashinfer_api
+def mxfp4_quantize(
+    a: torch.Tensor,
+    backend: str = "cuda",
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        backend (str, optional): Backend to use for quantization.
+            - "cuda": Use CUDA kernel (default, stable)
+            - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**)
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic
+            Dependent Launch). Only used when backend="cute-dsl".
+            If None, automatically detects based on device capability.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
--
+@flashinfer_api
+def mxfp4_dequantize(a_fp4, a_sf):
+    """
+    Dequantize input tensor from MXFP4 format.
+
+    Parameters:
+        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    return e2m1_and_ufp8sf_scale_to_float(
+        a_fp4.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8).reshape(-1),
+        torch.tensor([1.0], device=a_fp4.device),
+        32,
+        0,
+        True,
+    )
+
--
+@flashinfer_api
+def mxfp4_dequantize_host(
+    weight: torch.Tensor,
+    scale: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Dequantize input tensor from MXFP4 format on host.
+
+    Parameters:
+        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+        group_size (int, optional): Group size for dequantization. Defaults to 32.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
+    major, minor = get_compute_capability(
+        torch.device("cuda:0")
+    )  # use any cuda device to get a compute capability
--
+@flashinfer_api
+def nvfp4_batched_quantize(
+    a,
+    a_global_sf,
+    sf_vec_size=16,
+):
+    """
+    Quantize batched input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
--
+@flashinfer_api
+def nvfp4_quantize_paged_kv_cache(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_layout: str = "HND",
+    k_global_sf: Optional[torch.Tensor] = None,
+    v_global_sf: Optional[torch.Tensor] = None,
+) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    float,
+    float,
+]:
+    """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA.
+
+    Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling
+    (global FP32 + per-block FP8), and swizzles scale factors
+    for the SM100 trtllm-gen MHA kernel layout.
+
+    Args:
+        k_cache: Key cache tensor.
--
+@flashinfer_api
+def scaled_fp4_grouped_quantize(
+    a,
+    mask,
+    a_global_sf,
+):
+    """
+    Quantize batched input tensor to NVFP4 format with mask.
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        mask (torch.Tensor): Mask tensor to apply before quantization.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
+    a_fp4, a_sf = get_fp4_quantization_module(
+        device_arch
--
+@flashinfer_api
+def nvfp4_kv_dequantize(
+    fp4_data: torch.Tensor,
+    block_scales: torch.Tensor,
+    global_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """GPU dequantization of NVFP4 KV cache data with linear block scale layout.
+
+    Requires SM80+.
+
+    Args:
+        fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+        block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]``
+            with dtype uint8.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as fp4_data.
+        output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype.
--
+@flashinfer_api
+def nvfp4_kv_quantize(
+    input: torch.Tensor,
+    global_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """GPU quantization to NVFP4 KV cache format with linear block scale layout.
+
+    Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16.
+            K must be divisible by 16.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as input.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]:
+            - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+            - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8.
+    """
+    M, K = input.shape
--
+@flashinfer_api
+def mxfp8_quantize(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: Optional[bool] = None,
+    backend: Literal["cuda", "cute-dsl"] = "cuda",
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to MxFP8 format.
+
+    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        alignment (int, optional): sfVecSize. Defaults to 32.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0). Defaults to None.
+        backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are:
--
+@flashinfer_api
+def mxfp8_dequantize_host(
+    input: torch.Tensor,
+    scale_tensor: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> torch.Tensor:
+    """Dequantize input tensor from MxFP8 format.
+
+    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
+    back to float values using the associated scale factors.
+
+    Args:
+        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
+        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors.
+            If provided, it overrides is_sf_swizzled_layout. Defaults to None.
+            Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear.
+
+    Returns:
--
+@flashinfer_api
+def mxfp4_quantize_cute_dsl(
+    input: torch.Tensor,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format using CuTe-DSL kernel.
+
+    This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
+    - Global scale computed as (448 * 6) / max(|input|)
+    - UE8M0 scale factors
+    - E2M1 output format (4-bit, 2 values per byte)
+    - Swizzled (128x4) scale factor layout
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0).
--
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+    This is a GPU implementation with dual-path optimization:
+    - LINEAR layout: SF-block based iteration (fast)
+    - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+        alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)
```
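
The docstrings above describe MXFP4 data as packed E2M1 values (4-bit, two values per byte) with UE8M0 scale factors and a default group size of 32. The following is a minimal pure-Python sketch of what that decoding means for a single row. It is not FlashInfer's implementation: the function names are hypothetical, and the low-nibble-first packing order and the linear (non-swizzled) scale layout are assumptions made for illustration.

```python
# E2M1 magnitudes indexed by the low 3 bits (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0-2 the magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def unpack_fp4_row(packed: bytes) -> list:
    """Unpack K/2 bytes into K floats.

    Assumption: the low nibble holds the first element of each pair.
    """
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0xF))
        out.append(decode_e2m1(byte >> 4))
    return out


def dequantize_row_mxfp4(packed: bytes, scales_ue8m0: list, group_size: int = 32) -> list:
    """Apply one UE8M0 scale (2 ** (e - 127)) to each group of `group_size` values.

    Assumption: scales are stored linearly, one per consecutive group.
    """
    values = unpack_fp4_row(packed)
    return [
        v * 2.0 ** (scales_ue8m0[i // group_size] - 127)
        for i, v in enumerate(values)
    ]
```

For example, under these assumptions the byte `0x21` unpacks to `[0.5, 1.0]`, and a UE8M0 scale byte of 128 doubles both values. The swizzled 128x4 scale layout mentioned in the docstrings would need an extra index remapping that this sketch does not attempt.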


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Enhancements**
  * Normalization now accepts scale as either a float or tensor; passing a float emits a deprecation warning and is auto-converted for compatibility.
  * Attention/decoding API: cache-scale parameters are now optional keyword-only arguments with sensible defaults, simplifying common call patterns.
* **Tests**
  * Tests updated to match the adjusted attention/decoding call signature.
* **Chores**
  * Release version bumped to 0.6.7.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
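
The float-or-tensor behavior described in the summary can be illustrated with a small compatibility shim. This is only a sketch of the general pattern, not the code from this release: `normalize_scale` is a hypothetical name, and a plain one-element list stands in for a tensor so the example runs without PyTorch.

```python
import warnings
from typing import Sequence, Union


def normalize_scale(scale: Union[float, Sequence[float]]) -> Sequence[float]:
    """Hypothetical helper: accept a legacy float or a tensor-like scale.

    A float triggers a DeprecationWarning and is wrapped into a
    one-element list (the stand-in for a tensor in this sketch).
    """
    if isinstance(scale, (int, float)):
        warnings.warn(
            "Passing scale as a float is deprecated; pass a tensor instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return [float(scale)]
    # Already tensor-like: pass through unchanged.
    return scale
```

Callers that already pass a tensor-like value see no warning, while legacy float callers keep working and get a nudge to migrate.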