
Deprecation for gated_delta_rule_mtp's intermediate_states_buffer=True #2730

Merged
aleozlx merged 1 commit into flashinfer-ai:main from kahyunnam:knam/gated_delta_rule_mtp
Mar 20, 2026

Conversation

@kahyunnam
Collaborator

@kahyunnam kahyunnam commented Mar 9, 2026

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Deprecations
    • A deprecation notice has been introduced that alerts users when relying on default behavior. The current default is scheduled to change in a future release.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request prepares for a future breaking change in the gated_delta_rule_mtp function by deprecating the implicit default behavior of the disable_state_update parameter. It introduces a warning mechanism to inform users about the upcoming change in FlashInfer 0.7.0, encouraging them to explicitly define the parameter's value to avoid unexpected behavior.

Highlights

  • Deprecation Warning: Introduced a deprecation warning for the implicit default value of the disable_state_update parameter in the gated_delta_rule_mtp function. The default will change from True to False in FlashInfer 0.7.0.
  • Parameter Type Update: Modified the type hint for disable_state_update from bool to Optional[bool] to accommodate the new deprecation logic.
  • Docstring Enhancement: Updated the docstring for disable_state_update to clearly communicate the upcoming default change and advise users to explicitly pass the argument.


Changelog
  • flashinfer/gdn_decode.py
    • Imported logger for emitting warnings.
    • Updated the disable_state_update parameter's type hint to Optional[bool].
    • Enhanced the docstring for disable_state_update to include deprecation details and guidance.
    • Implemented runtime logic to issue a warning when disable_state_update is not explicitly provided, defaulting it to True for backward compatibility in the current version.
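
The deprecation shim described in this changelog can be sketched as follows. This is a hypothetical, self-contained illustration: the function name `gated_delta_rule_mtp_sketch` is invented, and `warnings.warn` stands in for FlashInfer's internal logger, which the actual code uses.

```python
import warnings
from typing import Optional


def gated_delta_rule_mtp_sketch(disable_state_update: Optional[bool] = None) -> bool:
    """Sketch of the Optional[bool]-with-deprecation-warning pattern."""
    if disable_state_update is None:
        # Caller relied on the implicit default: warn once per call site
        # that the default will flip from True to False in 0.7.0.
        warnings.warn(
            "The default of disable_state_update will change from True to "
            "False in FlashInfer 0.7.0; pass the argument explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        disable_state_update = True  # keep current behavior for now
    return disable_state_update
```

Callers that pass the argument explicitly never hit the warning branch, which is exactly what makes the eventual default flip non-breaking for them.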
Activity
  • No human activity (comments, reviews, or progress updates) has been recorded on this pull request yet.

@coderabbitai
Contributor

coderabbitai bot commented Mar 9, 2026

📝 Walkthrough

Walkthrough

This PR adds deprecation handling for the disable_state_update parameter in gated_delta_rule_mtp. The parameter signature changes from bool = False to Optional[bool] = None, with deprecation warnings emitted when the parameter is not explicitly provided and documentation updated to reflect the planned default behavior change.

Changes

Cohort / File(s) Summary
Deprecation Handling
flashinfer/gdn_decode.py
Updated gated_delta_rule_mtp function: changed disable_state_update parameter from bool = False to Optional[bool] = None. Added logger import, deprecation warning emission when parameter is None, and updated docstring to document the deprecation and future default value change.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐰 Whiskers twitch with deprecation glee,
A gentle warning for what will be,
From False to None, a softer way,
Default shall change another day! 🔔

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Title check ⚠️ Warning — The title references 'intermediate_states_buffer=True', but the actual change is the deprecation of the 'disable_state_update' parameter; the title is misleading and does not accurately reflect the primary change. Resolution: update the title to reflect the actual change, such as 'Deprecate implicit disable_state_update default in gated_delta_rule_mtp' or 'Add deprecation handling for gated_delta_rule_mtp disable_state_update parameter'.
  • Description check ⚠️ Warning — The PR description contains only the template structure with uncompleted sections; the 📌 Description section is empty, with no explanation of what the PR does or why the changes are needed. Resolution: fill in the 📌 Description section with a clear explanation of the deprecation changes, why the disable_state_update behavior is being deprecated, and how the new Optional[bool] parameter affects existing code.
✅ Passed checks (1 passed)
  • Docstring Coverage ✅ Passed — Docstring coverage is 100.00%, which meets the required threshold of 80.00%.


@kahyunnam kahyunnam changed the title from "deprecation for intermediate_states_buffer=True" to "Deprecation for gated_delta_rule_mtp's intermediate_states_buffer=True" Mar 9, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a deprecation warning for the disable_state_update parameter in gated_delta_rule_mtp to resolve an inconsistency between the function signature and its docstring. While the implementation is correct, I've raised a concern about the deprecation strategy, which seems to introduce a temporary breaking change that will be reverted in the future. My suggestion is to adopt a more direct path to the desired long-term default value to avoid user confusion. Also, the pull request title seems to have a typo and should probably be updated to reflect the change to disable_state_update.

@kahyunnam kahyunnam self-assigned this Mar 9, 2026
@kahyunnam kahyunnam added the v0.6.6 release blocker label for 0.6.6 label Mar 9, 2026
@kahyunnam kahyunnam marked this pull request as ready for review March 9, 2026 19:10
@kahyunnam kahyunnam requested a review from aleozlx March 9, 2026 19:11
Collaborator

@aleozlx aleozlx left a comment
lgtm from my perspective

@aleozlx aleozlx added the run-ci label Mar 9, 2026
@aleozlx
Copy link
Collaborator

aleozlx commented Mar 9, 2026

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gdn_decode.py (2)

479-494: 🛠️ Refactor suggestion | 🟠 Major

Add the repo-standard backend gate to this public API.

gated_delta_rule_mtp still advertises an SM90 requirement, but it does not expose the expected backend capability guard/helpers at the API boundary.

As per coding guidelines, "Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 479 - 494, The public API
gated_delta_rule_mtp currently requires SM90 but lacks the repo-standard backend
guard; update the symbol gated_delta_rule_mtp to be annotated with the
`@backend_requirement`(...) decorator (configured for SM90) and ensure the module
exposes/uses the helper predicates is_compute_capability_supported(cc) and
is_backend_supported() so callers can check availability; import the
backend_requirement decorator, add the decorator to gated_delta_rule_mtp, and if
missing implement or forward-export the two helper functions
(is_compute_capability_supported and is_backend_supported) so the API boundary
follows the repository guideline.

491-559: ⚠️ Potential issue | 🟠 Major

This still doesn't deprecate the intermediate_states_buffer=True path.

The new branch only warns on omitted disable_state_update. If callers still use the intermediate_states_buffer=True shorthand described by the PR objective, they will skip this branch, enter cache_intermediate_states, and then fail on .shape[0] at Line 608 because True is a bool, not a tensor. Please intercept that compatibility case before the buffer-shape logic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 491 - 559, The code fails when callers
pass intermediate_states_buffer=True because later code expects a tensor and
accesses intermediate_states_buffer.shape[0]; update gated_delta_rule_mtp to
intercept the compatibility shorthand before any buffer-shape logic: if
intermediate_states_buffer is True (or isinstance(intermediate_states_buffer,
bool) and intermediate_states_buffer), set a local flag
cache_intermediate_states=True, replace intermediate_states_buffer with None (or
a properly allocated tensor later) and emit a deprecation/warning, then continue
to the existing code paths that handle caching (e.g., cache_intermediate_states
and subsequent uses of intermediate_states_buffer.shape); ensure you reference
and adjust the variables intermediate_states_buffer and
cache_intermediate_states so no bool is used where a tensor is expected.
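
The interception the comment proposes can be sketched roughly as below. This is a hypothetical standalone sketch, not FlashInfer code: the helper name `normalize_buffer` is invented, and `warnings.warn` stands in for the library's logger. The point is simply to catch the bool shorthand before any `.shape` access.

```python
import warnings


def normalize_buffer(intermediate_states_buffer):
    """Intercept the legacy intermediate_states_buffer=True shorthand.

    Returns (buffer_or_None, cache_intermediate_states) so that later code
    only ever sees a tensor or None where a tensor is expected.
    """
    cache_intermediate_states = intermediate_states_buffer is not None
    if isinstance(intermediate_states_buffer, bool):
        # Legacy compatibility path: a bool was passed where a tensor
        # (or None) is expected; warn and convert before shape logic runs.
        warnings.warn(
            "Passing a bool for intermediate_states_buffer is deprecated; "
            "pass a preallocated tensor or None instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        cache_intermediate_states = intermediate_states_buffer
        intermediate_states_buffer = None  # allocate a real buffer later
    return intermediate_states_buffer, cache_intermediate_states
```

With this shape, the subsequent buffer-shape code can branch on `cache_intermediate_states` and never touch `.shape` on a bool.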
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)

551-559: Please add a regression test for the omitted-argument branch.

In the provided context, tests/gdn/test_decode_delta_rule.py:830-930 and benchmarks/bench_gdn_decode.py:1230-1270 already pass disable_state_update explicitly, so this new None path does not appear covered. A focused test should lock down both the current True behavior and the one-time deprecation signal when the argument is omitted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 551 - 559, Add a regression test that
calls gated_delta_rule_mtp() without passing disable_state_update to exercise
the None branch: assert that disable_state_update is implicitly set to True
(behaviorally by observing state updates), and capture/expect the one-time
deprecation warning emitted via logger.warning_once (use the test warning
capture or mocking facility to ensure it appears exactly once). Put the test
alongside the GDN decode tests, name it to indicate “omitted
disable_state_update” behavior, and ensure other existing tests that pass
disable_state_update explicitly remain unchanged.
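
The suggested regression test could look roughly like this. It is a hypothetical sketch: `_stub_mtp` is an invented stand-in for `gated_delta_rule_mtp` (so the example runs without GPU kernels), and `warnings.warn` replaces the library's one-time logger warning.

```python
import warnings
from typing import Optional


def _stub_mtp(disable_state_update: Optional[bool] = None) -> bool:
    """Invented stand-in reproducing only the deprecation branch."""
    if disable_state_update is None:
        warnings.warn(
            "disable_state_update default will change in 0.7.0",
            DeprecationWarning,
        )
        disable_state_update = True
    return disable_state_update


def test_omitted_disable_state_update_warns_and_defaults_true():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Omitted argument must behave as True for now...
        assert _stub_mtp() is True
    # ...and must emit the deprecation signal.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)


def test_explicit_argument_does_not_warn():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert _stub_mtp(disable_state_update=False) is False
    assert not caught
```

A real test against `gated_delta_rule_mtp` would additionally verify the state-update behavior on actual tensors, as the comment suggests.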
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/gdn_decode.py`:
- Around line 479-494: The public API gated_delta_rule_mtp currently requires
SM90 but lacks the repo-standard backend guard; update the symbol
gated_delta_rule_mtp to be annotated with the `@backend_requirement`(...)
decorator (configured for SM90) and ensure the module exposes/uses the helper
predicates is_compute_capability_supported(cc) and is_backend_supported() so
callers can check availability; import the backend_requirement decorator, add
the decorator to gated_delta_rule_mtp, and if missing implement or
forward-export the two helper functions (is_compute_capability_supported and
is_backend_supported) so the API boundary follows the repository guideline.
- Around line 491-559: The code fails when callers pass
intermediate_states_buffer=True because later code expects a tensor and accesses
intermediate_states_buffer.shape[0]; update gated_delta_rule_mtp to intercept
the compatibility shorthand before any buffer-shape logic: if
intermediate_states_buffer is True (or isinstance(intermediate_states_buffer,
bool) and intermediate_states_buffer), set a local flag
cache_intermediate_states=True, replace intermediate_states_buffer with None (or
a properly allocated tensor later) and emit a deprecation/warning, then continue
to the existing code paths that handle caching (e.g., cache_intermediate_states
and subsequent uses of intermediate_states_buffer.shape); ensure you reference
and adjust the variables intermediate_states_buffer and
cache_intermediate_states so no bool is used where a tensor is expected.

---

Nitpick comments:
In `@flashinfer/gdn_decode.py`:
- Around line 551-559: Add a regression test that calls gated_delta_rule_mtp()
without passing disable_state_update to exercise the None branch: assert that
disable_state_update is implicitly set to True (behaviorally by observing state
updates), and capture/expect the one-time deprecation warning emitted via
logger.warning_once (use the test warning capture or mocking facility to ensure
it appears exactly once). Put the test alongside the GDN decode tests, name it
to indicate “omitted disable_state_update” behavior, and ensure other existing
tests that pass disable_state_update explicitly remain unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bca4fa26-468b-4609-ae01-8520654633a2

📥 Commits

Reviewing files that changed from the base of the PR and between 2bb3e9e and 4d551eb.

📒 Files selected for processing (1)
  • flashinfer/gdn_decode.py

@flashinfer-bot
Collaborator

GitLab MR !393 has been created, and the CI pipeline #45733820 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx mentioned this pull request Mar 9, 2026
5 tasks
@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !393 has been created, and the CI pipeline #45735289 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx enabled auto-merge (squash) March 9, 2026 21:32
yzh119 pushed a commit that referenced this pull request Mar 9, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 "Gated-by" PR list


https://github.com/flashinfer-ai/flashinfer/pulls?q=is%3Aopen+is%3Apr+label%3Av0.6.6

#2730 Resolving
gated_delta_rule_mtp breaking change before release

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

**API changes review**

```diff
$ git diff v0.6.5 | grep -A20 @flashinfer_api                                                                                                                                                                         (version_bump✱)
 @flashinfer_api
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    state: torch.Tensor,
+    state: Optional[torch.Tensor],
     A_log: torch.Tensor,
     a: torch.Tensor,
     dt_bias: torch.Tensor,
@@ -951,6 +113,8 @@ def gated_delta_rule_decode_pretranspose(
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     use_qk_l2norm: bool = True,
+    initial_state: Optional[torch.Tensor] = None,
+    initial_state_indices: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Gated Delta Rule Decode kernel for single-token generation.

@@ -964,10 +128,11 @@ def gated_delta_rule_decode_pretranspose(
             Current key of shape ``[B, 1, H, K]``. Must be float16/bfloat16.
--
 @flashinfer_api
 def gated_delta_rule_mtp(
     q: torch.Tensor,
@@ -2398,7 +487,7 @@ def gated_delta_rule_mtp(
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     intermediate_states_buffer: Optional[torch.Tensor] = None,
-    disable_state_update: bool = True,
+    disable_state_update: bool = False,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
```

* gated_delta_rule_mtp will be resolved in PR
#2730

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
  * Version bump to 0.6.6

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
@flashinfer-bot
Collaborator

[FAILED] Pipeline #45735289: 8/20 passed

@aleozlx aleozlx added v0.6.7 release blocker label for 0.6.7 and removed v0.6.6 release blocker label for 0.6.6 labels Mar 11, 2026
@kahyunnam
Collaborator Author

@aleozlx does this still need to be merged? I'm a bit confused about the current state of the release plan.

frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
@aleozlx
Collaborator

aleozlx commented Mar 20, 2026

We still need to merge it.

0.6.6 ended up being cut before the breaking changes because the turnaround was unacceptable, so this fix is to be merged before we cut 0.6.7.

Regarding deprecation: the warning is to be posted through 0.7, and 0.8 can then ship the breaking change that was announced since 0.6.x.

@aleozlx aleozlx added run-ci and removed run-ci labels Mar 20, 2026
@aleozlx aleozlx merged commit 7cb016d into flashinfer-ai:main Mar 20, 2026
69 of 162 checks passed
ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 22, 2026
Resolve conflict in flashinfer/gdn_decode.py by accepting main's
disable_state_update deprecation warning from PR flashinfer-ai#2730.

AI-assisted: Claude Code

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
aleozlx added a commit that referenced this pull request Mar 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

fix api breaking changes for 0.6.7 release

## 🔍 Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

**API changes review**

API changes since v0.6.6

PR #2520 + commit e35c19e (fixed to be compatible)

  Function: xqa()
  Change: Added k_sf_cache=None, v_sf_cache=None as keyword-only params (after *). Backward-compatible.

PR #2618 (has PR #2730 to fix it)

  Function: gated_delta_rule_mtp()
  Change: disable_state_update: bool = True → Optional[bool] = None. Still defaults to True at runtime but emits a deprecation warning; will flip to False in 0.7.0.

PR #2775 (expected — cute DSL MoE cleanup)

  Function: blockscaled_contiguous_grouped_gemm_nvfp4()
  Change: Entire @flashinfer_api decorated function deleted.

  Function: blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()
  Change: Entire @flashinfer_api decorated function deleted.

  Function: blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()
  Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.

  Function: blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()
  Change: @flashinfer_api decorator removed; added enable_pdl: bool = True param.

  Function: CuteDslMoEWrapper.__init__()
  Change: Added enable_pdl: bool = True param. Backward-compatible.

  Function: cute_dsl_fused_moe_nvfp4()
  Change: Added enable_pdl: bool = True param. Backward-compatible.

PR #2428

  Function: rmsnorm_quant()
  Change: scale: float → scale: Union[float, torch.Tensor]; return type torch.Tensor → None.

  Function: fused_add_rmsnorm_quant()
  Change: scale: float → scale: Union[float, torch.Tensor].

Quantization functions (relocated, not removed)

All quantization APIs (fp4_quantize, block_scale_interleave, e2m1_and_ufp8sf_scale_to_float, shuffle_matrix_a, shuffle_matrix_sf_a, nvfp4_quantize, nvfp4_batched_quantize, scaled_fp4_grouped_quantize, mxfp4_quantize, mxfp4_dequantize, mxfp4_dequantize_host, mxfp8_quantize, mxfp8_dequantize_host) were moved from flashinfer/fp4_quantization.py and flashinfer/fp8_quantization.py to flashinfer/quantization/. Signatures, @flashinfer_api decorators, and __init__.py exports are preserved. No breakage.

```diff
$ git diff v0.6.6 | grep -A20 "@flashinfer_api"                                               
     @flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
         sinks: Optional[torch.Tensor] = None,
         q_len_per_req: Optional[int] = 1,
         skip_softmax_threshold_scale_factor: Optional[float] = None,
+        kv_block_scales: Optional[
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        ] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.

@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper:
             enable_pdl = device_support_pdl(q.device)
         k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)

+        # Unpack kv_block_scales
+        key_block_scales = None
+        value_block_scales = None
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                key_block_scales, value_block_scales = kv_block_scales
--
-@flashinfer_api
-def fp4_quantize(
-    input: torch.Tensor,
-    global_scale: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    sf_use_ue8m0: bool = False,
-    is_sf_swizzled_layout: bool = True,
-    is_sf_8x4_layout: bool = False,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to FP4 format.
-
-    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
-@flashinfer_api
-def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
-    """Swizzle block scale tensor for FP4 format.
-
-    This function swizzles the block scale tensor to optimize memory access patterns
-    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
-
-    Args:
-        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
-
-    Returns:
-        torch.Tensor: Swizzled tensor with the same shape as input.
-
-    Raises:
-        AssertionError: If input dtype is not uint8 or bfloat16.
-    """
-    # TODO(shuw): check input dtype is uint8
-    assert (
-        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
-    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
-
--
-@flashinfer_api
-def e2m1_and_ufp8sf_scale_to_float(
-    e2m1_tensor: torch.Tensor,
-    ufp8_scale_tensor: torch.Tensor,
-    global_scale_tensor: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    ufp8_type: int = 1,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
-
-    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
-    back to float values using the associated UFP8 scale factors and global scale.
-
-    Args:
-        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
-        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
-        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
-@flashinfer_api
-def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
-    """
-    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
-    """
-    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
-
-    return input_tensor[row_indices.to(input_tensor.device)]
-
-
-@flashinfer_api
-def shuffle_matrix_sf_a(
-    input_tensor: torch.Tensor,
-    epilogue_tile_m: int,
-    num_elts_per_sf: int = 16,
-):
-    """
-    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
-    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
-    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
-    layout.
-    This function expects the input to be in linear layout. It's done this
-    way because the scaling factors in the NVFP4 checkpoints are quantized
-    and are in linear layout.
-    This function doesn't add padding.
-    """
-
-    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
-
-    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
-
--
-@flashinfer_api
-def nvfp4_quantize(
-    a,
-    a_global_sf,
-    sfLayout=SfLayout.layout_128x4,
-    do_shuffle=False,
-    sf_vec_size=16,
-    enable_pdl=None,
-):
-    """
-    Quantize input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
-        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-
--
-@flashinfer_api
-def mxfp4_quantize(a):
-    """
-    Quantize input tensor to MXFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-            - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-    """
-    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
-    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
-    return a_fp4, a_sf
-
-
-@flashinfer_api
-def mxfp4_dequantize(a_fp4, a_sf):
-    """
-    Dequantize input tensor from MXFP4 format.
-
-    Parameters:
-        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    return e2m1_and_ufp8sf_scale_to_float(
-        a_fp4.cpu().view(torch.uint8),
-        a_sf.cpu().view(torch.uint8).reshape(-1),
-        torch.tensor([1.0], device=a_fp4.device),
-        32,
-        0,
-        True,
-    )
-
--
-@flashinfer_api
-def mxfp4_dequantize_host(
-    weight: torch.Tensor,
-    scale: torch.Tensor,
-    group_size: int = 32,
-) -> torch.Tensor:
-    """
-    Dequantize input tensor from MXFP4 format on host.
-
-    Parameters:
-        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-        group_size (int, optional): Group size for dequantization. Defaults to 32.
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
-    major, minor = get_compute_capability(
-        torch.device("cuda:0")
-    )  # use any cuda device to get a compute capability
--
-@flashinfer_api
-def nvfp4_batched_quantize(
-    a,
-    a_global_sf,
-    sf_vec_size=16,
-):
-    """
-    Quantize batched input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
--
-@flashinfer_api
-def scaled_fp4_grouped_quantize(
-    a,
-    mask,
-    a_global_sf,
-):
-    """
-    quantize batched input tensor to NVFP4 format with mask.
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
-    a_fp4, a_sf = get_fp4_quantization_module(
-        device_arch
--
-@flashinfer_api
-def mxfp8_quantize(
-    input: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-    alignment: int = 32,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to MxFP8 format.
-
-    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-        alignment (int, optional): sfVecSize. Defaults to 32.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3
--
-@flashinfer_api
-def mxfp8_dequantize_host(
-    input: torch.Tensor,
-    scale_tensor: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Dequantize input tensor from MxFP8 format.
-
-    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
-    back to float values using the associated scale factors.
-
-    Args:
-        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
-        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
-
-    Returns:
-        torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
-
-    """
-
--
-@flashinfer_api
 def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     vectorized_f32: bool = True,
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads.

@@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     major, minor = get_compute_capability(a.device)
     if major != 10:
         raise ValueError(
-            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). "
+            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). "
             f"Got SM{major}{minor}."
         )

--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (128, 128),
-    cluster_shape_mn: Tuple[int, int] = (1, 1),
-    sm_count: Optional[int] = None,
-) -> torch.Tensor:
-    """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization.
-
--
-@flashinfer_api
 def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     cluster_shape_mn: Tuple[int, int] = (2, 1),
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads.

@@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
             expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
         token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
         out: Optional output tensor, shape (seq_len, n). Created if None.
-             This tensor is used for atomic accumulation, so it should be zero-initialized.
+             This tensor is used for atomic accumulation. If `out` is
+             provided, it must already be zero-initialized by the caller.
+             If `out` is None, this function allocates a zero-initialized
+             output tensor. Passing a non-zeroed `out` buffer will silently
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    out_scale: Optional[torch.Tensor] = None,
-    global_scale: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (256, 128),
-    cluster_shape_mn: Tuple[int, int] = (2, 1),
-    vectorized_f32: bool = True,
-    sm_count: Optional[int] = None,
--
     @flashinfer_api
     def __init__(
         self,
@@ -347,6 +355,7 @@ class CuteDslMoEWrapper:
         sf_vec_size: int = 16,
         output_dtype: torch.dtype = torch.bfloat16,
         device: str = "cuda",
+        enable_pdl: bool = True,
     ):
         """Initialize the MoE wrapper.

@@ -363,6 +372,7 @@ class CuteDslMoEWrapper:
             sf_vec_size: Scale factor vector size. Default: 16.
             output_dtype: Output data type. Default: torch.bfloat16.
             device: Device for buffer allocation. Default: "cuda".
+            enable_pdl: Enable Programmatic Dependent Launch. Default: True.
         """
         self.num_experts = num_experts
         self.top_k = top_k
@@ -376,6 +386,7 @@ class CuteDslMoEWrapper:
         self.sf_vec_size = sf_vec_size
--
     @flashinfer_api
@@ -550,9 +570,10 @@ class CuteDslMoEWrapper:
                 f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
             )

-        # Allocate output buffer if not using pre-allocated one
+        # Slice the pre-allocated buffer to the active batch so that
+        # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
         if self.use_cuda_graph:
-            moe_output = self._moe_output
+            moe_output = self._moe_output[:num_tokens]
         else:
             moe_output = torch.empty(
                 (num_tokens, self.hidden_size),
@@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Internal implementation called by auto-tuner for functional API."""
--
 @flashinfer_api
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
@@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Run fused MoE computation using CuteDSL NVFP4 kernels.

+    Supported architectures: SM100, SM103.
+
     This is the simple functional API. For CUDA graph support, use
     `CuteDslMoEWrapper` instead.

@@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4(
         local_expert_offset=local_expert_offset,
         use_fused_finalize=use_fused_finalize,
         output_dtype=output_dtype,
+        enable_pdl=enable_pdl,
--
 @flashinfer_api
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
@@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose(
         - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
           and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
           (supports both the direct ``state`` path and the pool+indices path).
-        - pool+indices (``initial_state``/``initial_state_indices``) only supported
-          via the bf16 fast path; float32 state raises an error.
+        - pool+indices (``initial_state``/``initial_state_indices``) supported on
+          both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path
+          (T=1). The float32 path also supports negative indices for padding.
         - Legacy path (float32 state, T=1): K and V must be multiples of 4.
     """
     # Validate input shapes
@@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose(
         return_state = initial_state if use_pool else state
         return output, return_state

-    # Legacy path: T=1 only, float32 state (no pool+indices support)
-    assert not use_pool, (
--
 @flashinfer_api
 def gated_delta_rule_mtp(
     q: torch.Tensor,
@@ -2427,7 +489,7 @@ def gated_delta_rule_mtp(
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     intermediate_states_buffer: Optional[torch.Tensor] = None,
-    disable_state_update: bool = True,
+    disable_state_update: Optional[bool] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -2463,8 +525,15 @@ def gated_delta_rule_mtp(
         intermediate_states_buffer (Optional[torch.Tensor]):
             Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``.
             If None, intermediate states are not cached.
-        disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``True``.
+        disable_state_update (Optional[bool]):
+            If True, the initial state is not updated. Currently defaults to ``True``.
+            Please pass this argument explicitly — the default will change to ``False``
--
 @flashinfer_api
@@ -60,16 +120,14 @@ def rmsnorm(
     output: torch.Tensor
         Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
     """
-    if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
     if out is None:
         out = torch.empty_like(input)
-    _rmsnorm(out, input, weight, eps, enable_pdl)
+    _rmsnorm_impl(out, input, weight, eps, enable_pdl)
     return out


 @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
-def _rmsnorm(
+def _rmsnorm_impl(
     out: torch.Tensor,
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -78,11 +136,21 @@ def _rmsnorm(
--
 @flashinfer_api
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
@@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek(
         If return_lse is False, the output will be a single tensor.
     """
     if not is_sm12x_supported(query.device):
-        major, minor = get_compute_capability(query.device)
-        if major == 12:
-            min_cuda = "13.0" if minor >= 1 else "12.8"
-            raise ValueError(
-                f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} "
-                f"for SM12{minor}x GPUs."
-            )
         raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.")
     assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
         "currently only support deepseek r1 192 query and 128 value"
     )
-    module = get_trtllm_fmha_v2_module()
+    module = get_trtllm_fmha_v2_sm120_module()
     is_e4m3 = query.dtype == torch.float8_e4m3fn
--
+@flashinfer_api
+def trtllm_fmha_v2_prefill(
+    qkv: Union[
+        torch.Tensor,
+        Tuple[torch.Tensor, torch.Tensor],
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ],
+    input_layout: str,
+    workspace_buffer: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_q_len: int,
+    max_kv_len: int,
+    bmm1_scale: float,
+    bmm2_scale: float,
+    batch_size: int,
+    cum_seq_lens_q: torch.Tensor,
+    cum_seq_lens_kv: torch.Tensor,
+    block_tables: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: Optional[Union[torch.dtype, str]] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
--
+@flashinfer_api
+def fp4_quantize(
+    input: torch.Tensor,
+    global_scale: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    sf_use_ue8m0: bool = False,
+    is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to FP4 format.
+
+    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
+@flashinfer_api
+def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
+    """Swizzle block scale tensor for FP4 format.
+
+    This function swizzles the block scale tensor to optimize memory access patterns
+    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
+
+    Args:
+        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
+
+    Returns:
+        torch.Tensor: Swizzled tensor with the same shape as input.
+
+    Raises:
+        AssertionError: If input dtype is not uint8 or bfloat16.
+    """
+    # TODO(shuw): check input dtype is uint8
+    assert (
+        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
+    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
+
--
+@flashinfer_api
+def e2m1_and_ufp8sf_scale_to_float(
+    e2m1_tensor: torch.Tensor,
+    ufp8_scale_tensor: torch.Tensor,
+    global_scale_tensor: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    ufp8_type: int = 1,
+    is_sf_swizzled_layout: bool = True,
+) -> torch.Tensor:
+    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
+
+    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
+    back to float values using the associated UFP8 scale factors and global scale.
+
+    Args:
+        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
+        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
+        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
+        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
+@flashinfer_api
+def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
+    """
+    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
+    """
+    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
+
+    return input_tensor[row_indices.to(input_tensor.device)]
+
+
+@flashinfer_api
+def shuffle_matrix_sf_a(
+    input_tensor: torch.Tensor,
+    epilogue_tile_m: int,
+    num_elts_per_sf: int = 16,
+):
+    """
+    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
+    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
+    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
+    layout.
+    This function expects the input to be in linear layout. It's done this
+    way because the scaling factors in the NVFP4 checkpoints are quantized
+    and are in linear layout.
+    This function doesn't add padding.
+    """
+
+    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
+
+    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
+
--
+@flashinfer_api
+def nvfp4_quantize(
+    a,
+    a_global_sf,
+    sfLayout=SfLayout.layout_128x4,
+    do_shuffle=False,
+    sf_vec_size=16,
+    enable_pdl=None,
+):
+    """
+    Quantize input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
+        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability. Defaults to None.
+
--
+@flashinfer_api
+def mxfp4_quantize(
+    a: torch.Tensor,
+    backend: str = "cuda",
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        backend (str, optional): Backend to use for quantization.
+            - "cuda": Use CUDA kernel (default, stable)
+            - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**)
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic
+            Dependent Launch). Only used when backend="cute-dsl".
+            If None, automatically detects based on device capability.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
--
+@flashinfer_api
+def mxfp4_dequantize(a_fp4, a_sf):
+    """
+    Dequantize input tensor from MXFP4 format.
+
+    Parameters:
+        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    return e2m1_and_ufp8sf_scale_to_float(
+        a_fp4.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8).reshape(-1),
+        torch.tensor([1.0], device=a_fp4.device),
+        32,
+        0,
+        True,
+    )
+
--
+@flashinfer_api
+def mxfp4_dequantize_host(
+    weight: torch.Tensor,
+    scale: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Dequantize input tensor from MXFP4 format on host.
+
+    Parameters:
+        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+        group_size (int, optional): Group size for dequantization. Defaults to 32.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
+    major, minor = get_compute_capability(
+        torch.device("cuda:0")
+    )  # use any cuda device to get a compute capability
--
+@flashinfer_api
+def nvfp4_batched_quantize(
+    a,
+    a_global_sf,
+    sf_vec_size=16,
+):
+    """
+    Quantize batched input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
--
+@flashinfer_api
+def nvfp4_quantize_paged_kv_cache(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_layout: str = "HND",
+    k_global_sf: Optional[torch.Tensor] = None,
+    v_global_sf: Optional[torch.Tensor] = None,
+) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    float,
+    float,
+]:
+    """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA.
+
+    Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling
+    (global FP32 + per-block FP8), and swizzles scale factors
+    for the SM100 trtllm-gen MHA kernel layout.
+
+    Args:
+        k_cache: Key cache tensor.
--
+@flashinfer_api
+def scaled_fp4_grouped_quantize(
+    a,
+    mask,
+    a_global_sf,
+):
+    """
+    quantize batched input tensor to NVFP4 format with mask.
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        mask (torch.Tensor): Mask tensor to apply before quantization.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
+    a_fp4, a_sf = get_fp4_quantization_module(
+        device_arch
--
+@flashinfer_api
+def nvfp4_kv_dequantize(
+    fp4_data: torch.Tensor,
+    block_scales: torch.Tensor,
+    global_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """GPU dequantization of NVFP4 KV cache data with linear block scale layout.
+
+    Requires SM80+.
+
+    Args:
+        fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+        block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]``
+            with dtype uint8.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as fp4_data.
+        output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype.
--
+@flashinfer_api
+def nvfp4_kv_quantize(
+    input: torch.Tensor,
+    global_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """GPU quantization to NVFP4 KV cache format with linear block scale layout.
+
+    Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16.
+            K must be divisible by 16.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as input.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]:
+            - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+            - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8.
+    """
+    M, K = input.shape
--
+@flashinfer_api
+def mxfp8_quantize(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: Optional[bool] = None,
+    backend: Literal["cuda", "cute-dsl"] = "cuda",
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to MxFP8 format.
+
+    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        alignment (int, optional): Scale-factor vector size (sfVecSize). Defaults to 32.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0). Defaults to None.
+        backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are:
--
+@flashinfer_api
+def mxfp8_dequantize_host(
+    input: torch.Tensor,
+    scale_tensor: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> torch.Tensor:
+    """Dequantize input tensor from MxFP8 format.
+
+    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
+    back to float values using the associated scale factors.
+
+    Args:
+        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
+        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors.
+            If provided, it overrides is_sf_swizzled_layout. Defaults to None.
+            Available options are ``SfLayout.layout_128x4`` and ``SfLayout.layout_linear``.
+
+    Returns:
--
+@flashinfer_api
+def mxfp4_quantize_cute_dsl(
+    input: torch.Tensor,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format using CuTe-DSL kernel.
+
+    This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
+    - Global scale computed as (448 * 6) / max(|input|)
+    - UE8M0 scale factors
+    - E2M1 output format (4-bit, 2 values per byte)
+    - Swizzled (128x4) scale factor layout
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0).
--
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+    This is a GPU implementation with dual-path optimization:
+    - LINEAR layout: SF-block based iteration (fast)
+    - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+        alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)
```
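Taken together, the excerpts above pin down the NVFP4 layout (two E2M1 values packed per uint8 byte, one FP8 E4M3 scale per 16-element block) and the MXFP4 global-scale formula (``(448 * 6) / max(|input|)``). The following is a minimal pure-Python sketch of those relationships — an illustration of the documented layout, not the library kernels, and the helper names here are invented for the example:

```python
def nvfp4_shapes(M, K):
    # NVFP4 packs two 4-bit E2M1 values per uint8 byte  -> fp4_data [M, K/2]
    # and keeps one FP8 E4M3 scale per 16-element block -> block_scales [M, K/16]
    assert K % 16 == 0, "K must be divisible by 16"
    return (M, K // 2), (M, K // 16)

def pack_nibbles(codes):
    # Pack an even-length list of 4-bit codes (0..15) two per byte,
    # low nibble first, mirroring the packed uint8 layout.
    assert len(codes) % 2 == 0
    return [(codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
            for i in range(0, len(codes), 2)]

def mxfp4_global_scale(values):
    # Documented formula: (448 * 6) / max(|input|), where 448 is the
    # FP8 E4M3 max representable value and 6 is the FP4 E2M1 max.
    amax = max(abs(v) for v in values)
    return (448.0 * 6.0) / amax

print(nvfp4_shapes(4, 64))              # ((4, 32), (4, 4))
print(pack_nibbles([0, 1, 2, 3]))       # [16, 50]
print(mxfp4_global_scale([1.0, -2.0]))  # 1344.0
```

The shape helpers make it easy to see why the docstrings require ``K`` divisible by 16 and return uint8 buffers of half and one-sixteenth the logical width.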


<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **Enhancements**
  * Normalization now accepts scale as either a float or tensor; passing a float emits a deprecation warning and is auto-converted for compatibility.
  * Attention/decoding API: cache-scale parameters are now optional keyword-only arguments with sensible defaults, simplifying common call patterns.
* **Tests**
  * Tests updated to match the adjusted attention/decoding call signature.
* **Chores**
  * Release version bumped to 0.6.7.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Labels: `run-ci`, `v0.6.7` (release blocker label for 0.6.7)
