
feat(gdn): add unified decode API and deprecation shims (RFC 5.7, 5.8) #2706

Open
Dayuxiaoshui wants to merge 3 commits into flashinfer-ai:main from Dayuxiaoshui:main


@Dayuxiaoshui Dayuxiaoshui commented Mar 6, 2026

📌 Description

Implements RFC 5.7/5.8: adds the unified GDN decode API gated_delta_rule_decode_unified (with the state_layout, state_indices, intermediate_states_buffer, and disable_state_update parameters), renames the existing backends to _*_impl with dispatch by layout/dtype/T, adds deprecation shims for the three legacy decode APIs, exports the new function in __init__.py, and adds tests/gdn/test_gdn_decode_unified.py (cross-checks, intermediate buffer, edge cases, error paths).

🔍 Related Issues

Closes #2687

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a unified decode API that automatically selects the appropriate backend based on state layout and runtime configuration.
  • Deprecations

    • Backend-specific decode variants are deprecated in favor of the unified API; legacy variants remain with deprecation warnings.
  • Improvements

    • Stronger validation and clearer error messages for state/layout/parameter combinations.
    • Expanded tests and benchmarks to validate parity across implementations.

- Add gated_delta_rule_decode_unified as single entry point with state_layout (VK/KV), state_indices, intermediate_states_buffer, disable_state_update.

- Rename existing implementations to _*_impl; unified dispatches by state_layout, state dtype, and T.

- Add deprecation shims for gated_delta_rule_decode_pretranspose, gated_delta_rule_decode (KV), gated_delta_rule_mtp with DeprecationWarning.

- Export gated_delta_rule_decode_unified in __init__.py.

- Add tests/gdn/test_gdn_decode_unified.py: cross-check vs _*_impl, intermediate_states_buffer, edge cases (pool_size=1, B=1), error paths.
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request streamlines the Gated Delta Rule (GDN) decode functionality by introducing a unified API. This change centralizes the dispatch logic for different GDN decode backends, improving maintainability and future extensibility. Concurrently, it gracefully transitions users away from older, specialized decode functions by providing deprecation shims, ensuring a smooth migration path while consolidating the API surface.

Highlights

  • Unified GDN Decode API: Introduced a new unified GDN decode API, gated_delta_rule_decode_unified, to serve as a single entry point for various decode operations (T=1 and MTP).
  • Backend Renaming: Renamed existing GDN decode backend functions (e.g., gated_delta_rule_decode_pretranspose) to internal _*_impl versions to support the new unified dispatch mechanism.
  • Deprecation Shims: Implemented deprecation shims for the three legacy GDN decode APIs, guiding users to the new unified API and issuing DeprecationWarnings.
  • New Test Coverage: Added a new test file, test_gdn_decode_unified.py, to thoroughly validate the functionality of the unified API, including cross-checking with legacy behavior, intermediate buffer handling, edge cases, and error paths.
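
The deprecation-shim pattern described in the highlights can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `decode_unified` and `decode_pretranspose` are hypothetical stand-ins with simplified signatures, and the real functions take tensors rather than plain values.

```python
import warnings


def decode_unified(x, state_layout="VK"):
    """Hypothetical stand-in for the unified entry point."""
    return (state_layout, x)


def decode_pretranspose(x):
    """Deprecated shim: emits a warning, then forwards to the unified function."""
    warnings.warn(
        "decode_pretranspose is deprecated; "
        "use decode_unified(..., state_layout='VK') instead",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not to the shim
    )
    return decode_unified(x, state_layout="VK")


# Capture warnings to confirm the shim both warns and forwards.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    shim_result = decode_pretranspose(42)
```

Passing `stacklevel=2` makes the warning point at the caller's line, which is usually what someone migrating off the legacy name needs to see.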


Changelog
  • flashinfer/__init__.py
    • Exported the new gated_delta_rule_decode_unified function.
  • flashinfer/gdn_decode.py
    • Imported the warnings module for deprecation notices.
    • Renamed gated_delta_rule_decode_pretranspose to _gated_delta_rule_decode_pretranspose_impl.
    • Renamed gated_delta_rule_decode to _gated_delta_rule_decode_kv_impl.
    • Renamed gated_delta_rule_mtp to _gated_delta_rule_mtp_impl.
    • Added the gated_delta_rule_decode_unified function, which dispatches to the appropriate _impl backend based on parameters.
    • Re-added gated_delta_rule_decode_pretranspose, gated_delta_rule_decode, and gated_delta_rule_mtp as deprecation shims that call the unified API and issue warnings.
  • tests/gdn/test_gdn_decode_unified.py
    • Added a new test file to verify the gated_delta_rule_decode_unified API.
    • Included tests to ensure the unified API matches the behavior of the now-internal _impl functions for various configurations (VK bf16 T=1, VK fp32 T=1, KV fp32 T=1, VK fp32 MTP).
    • Added tests for intermediate state buffer handling in MTP.
    • Included tests for edge cases like batch size 1 and pool size 1.
    • Added tests for error handling of unsupported state_layout values, state_indices with KV layout, and VK fp32 T>1 without a pool.
Activity
  • The pull request introduces a new unified API for GDN decode operations, along with deprecation shims for older functions.
  • Existing backend implementations were refactored and renamed to internal functions.
  • Comprehensive tests were added to validate the new unified API's correctness and error handling.

@coderabbitai
Contributor

coderabbitai bot commented Mar 6, 2026

📝 Walkthrough


Adds a unified gated_delta_rule_decode dispatcher in flashinfer.gdn_decode that routes VK/KV/MTP backends based on layout, dtype, and sequence length; renames previous public backend functions to internal impl variants and provides deprecation shims. Exposes a top-level import and updates benchmarks and tests to use backend-specific variant(s).

Changes

| Cohort / File(s) | Summary |
|---|---|
| Unified Decode API (flashinfer/gdn_decode.py) | Added new public dispatcher gated_delta_rule_decode(...) with routing logic for VK/KV/MTP, dtype and shape validations, _check_state_indices_bounds helper, renamed prior public backends to internal _*_impl variants, and added deprecation shim wrappers that emit DeprecationWarning. |
| Package Exports (flashinfer/__init__.py) | Exposed the new symbol via package initializer by importing gated_delta_rule_decode_unified from flashinfer.gdn_decode inside the CuteDSL MoE conditional block. |
| Benchmarks (benchmarks/bench_gdn_decode.py) | Replaced uses/imports of gated_delta_rule_decode with the KV-specific variant gated_delta_rule_decode_kv for nontranspose/KV benchmarking paths. |
| Tests, single test update (tests/gdn/test_decode_delta_rule.py) | Swapped import and call site to gated_delta_rule_decode_kv for the nontranspose test path. |
| Tests, new coverage (tests/gdn/test_gdn_decode.py) | Added comprehensive tests validating parity and error handling across unified dispatcher and VK/KV/MTP paths, dtype variants, pool/indexing behavior, and deprecation shims. |

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant UnifiedDecode as gated_delta_rule_decode
    participant ValidateState as _check_state_indices_bounds
    participant VKBackend as _gated_delta_rule_decode_pretranspose_impl
    participant KVBackend as _gated_delta_rule_decode_kv_impl
    participant MTPBackend as _gated_delta_rule_mtp_impl

    Caller->>UnifiedDecode: call(state_layout, state, state_indices, tokens, ...)
    UnifiedDecode->>ValidateState: validate state_indices (if provided)
    ValidateState-->>UnifiedDecode: ok / raise ValueError

    alt state_layout == "KV" OR tokens.T == 1 (KV path)
        UnifiedDecode->>KVBackend: dispatch to KV impl
        KVBackend-->>UnifiedDecode: return output
    else state_layout == "VK" AND bf16 AND pool AND tokens.T > 1 (MTP)
        UnifiedDecode->>MTPBackend: dispatch to MTP impl
        MTPBackend-->>UnifiedDecode: return output
    else state_layout == "VK"
        UnifiedDecode->>VKBackend: dispatch to VK/pretranspose impl
        VKBackend-->>UnifiedDecode: return output
    end

    UnifiedDecode-->>Caller: final decoded result or error
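
A rough sketch of how such routing could look in code is given below. It is an assumption-laden illustration: the rules follow the tested configurations and error paths listed in the changelog earlier in this thread (VK bf16, VK fp32 T=1, KV fp32 T=1, VK fp32 MTP with a pool), and they may not match the actual dispatcher. `select_backend` is a hypothetical helper, and dtypes are plain strings so the example runs without torch.

```python
def select_backend(state_layout, state_dtype, T, has_pool):
    """Return the name of the backend a unified decode call would use.

    Hypothetical routing table; the real dispatcher also validates shapes.
    """
    if state_layout not in ("VK", "KV"):
        raise ValueError(f"unsupported state_layout: {state_layout!r}")

    if state_layout == "KV":
        # Assumed rule: pooled indexing is rejected for the KV layout.
        if has_pool:
            raise ValueError("state_indices is not supported with KV layout")
        return "kv"

    # VK layout from here on.
    if state_dtype == "bfloat16":
        if T not in (1, 2, 3, 4):
            raise ValueError(f"VK bf16 path requires T in {{1,2,3,4}}, got T={T}")
        return "pretranspose"

    if T == 1:
        return "pretranspose"
    if not has_pool:
        raise ValueError("VK fp32 with T > 1 requires a state pool")
    return "mtp"
```

Keeping the routing in a pure function like this makes the dispatch rules testable without a GPU, separate from the kernels they select.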

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

v0.6.2, model: qwen3-next, ready

Suggested reviewers

  • yzh119
  • bkryu
  • cyx-6
  • nvmbreughe
  • kahyunnam
  • jimmyzho
  • jiahanc
  • kaixih

Poem

🐰 I hopped through code with eager paws,
Unified decode obeys new laws,
VK, KV, MTP — all in a row,
Old names wink and say "we'll go",
Rollback ready, seeds set to sow. 🌱

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately captures the main changes: adding a unified decode API with deprecation shims and corresponding RFC implementation. |
| Description check | ✅ Passed | The description includes all key sections (Description, Related Issues, Checklist) but the Pre-commit and Tests checkboxes are not marked as complete. |
| Linked Issues check | ✅ Passed | The PR implements core unified decode API dispatch logic, deprecation shims, and test coverage for VK/KV/MTP layouts and error cases, fulfilling primary RFC 5.7/5.8 objectives. |
| Out of Scope Changes check | ✅ Passed | All changes align with RFC 5.7/5.8 scope: unified API, dispatch routing, deprecation shims, and comprehensive test coverage for the decode layer. |


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new unified API gated_delta_rule_decode_unified for Gated Delta Rule (GDN) decode operations, simplifying the API surface by renaming old backend implementations, adding deprecation shims, and including comprehensive tests. However, a security audit identified critical issues related to insecure input validation using assert statements and missing bounds checks for user-supplied indices, which could lead to out-of-bounds memory access on the GPU. These assert statements should be replaced with explicit checks for better robustness.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2662-2679: The unified public entrypoint
gated_delta_rule_decode_unified is missing the required backend capability
guard; wrap it with the `@backend_requirement` decorator and ensure the code calls
the helper checks (is_compute_capability_supported(cc) and
is_backend_supported()) for the SM90+ requirement so callers fail fast on
unsupported GPUs; apply the same change to the other related unified API(s)
around the other decode function(s) referenced (the functions at the diff around
the 2730–2732 region) so both use the backend guard and support helpers
consistently.
- Around line 2704-2706: state_indices can be negative (padding) but the MTP
path/kernel currently skips padded rows leaving preallocated output rows stale;
update the call site that forwards state_indices into the MTP path to either (A)
validate and reject negative indices by raising an error if any(state_indices <
0) when an output buffer is supplied, or (B) proactively zero the corresponding
output rows before launching the MTP/kernel when output is provided;
specifically, check state_indices for negatives, and if choosing (B) zero
output[mask] (where mask = state_indices < 0) prior to the mtp kernel launch so
padded rows do not retain old values. Ensure this change is applied both where
state_indices is forwarded into the MTP path and in the analogous block handling
lines around the second occurrence mentioned (the other MTP call).
- Around line 2895-2930: The shim gated_delta_rule_decode_pretranspose lost
legacy validation: restore checks so callers must provide either state (state)
or the per-step initial state pair (initial_state and initial_state_indices) but
not both, and if initial_state is provided then initial_state_indices must also
be provided (and conversely initial_state_indices must be None when
initial_state is None); add explicit ValueError(s) with clear messages before
delegating to gated_delta_rule_decode_unified to prevent bad calls (refer to
symbols initial_state, initial_state_indices, state,
gated_delta_rule_decode_pretranspose, and gated_delta_rule_decode_unified).
- Around line 2771-2815: The BF16 branch currently always forwards the state
into _gated_delta_rule_decode_pretranspose_impl (via
initial_state/initial_state_indices or state), which prevents honoring
disable_state_update and prevents filling intermediate_states_buffer; change the
BF16 branch so that if disable_state_update is True you do not pass the current
state (pass state=None and omit initial_state/initial_state_indices) and if
intermediate_states_buffer is requested do not use the unified BF16 path but
fall back to the non-bf16/FP32 decode path that supports population of the
rollback buffer (or otherwise implement buffer population), using the same
checks around state.shape/state_indices but routing to the alternative code path
instead of always calling _gated_delta_rule_decode_pretranspose_impl.

In `@tests/gdn/test_gdn_decode_unified.py`:
- Around line 215-357: The tests miss exercising the bfloat16 dispatcher path
and fail to assert intermediate buffer rollback; update the param setup so at
least one T>1 case uses state_pool and intermediate buffers with
dtype=torch.bfloat16 (so gated_delta_rule_decode_unified hits the state.dtype ==
torch.bfloat16 branch) and add an assertion comparing intermed_unified to
intermed_legacy (e.g., torch.testing.assert_close(intermed_unified,
intermed_legacy, atol=5e-3, rtol=5e-3)) in
test_unified_vk_fp32_mtp_with_intermediate_buffer_matches_mtp; ensure the same
dtype change is applied consistently for pool_unified/pool_legacy and for any
places that construct intermed_buf so the BF16 dispatcher path in
gated_delta_rule_decode_unified and rollback caching behavior are both
exercised.

ℹ️ Review info

📥 Commits

Reviewing files that changed from the base of the PR and between 124a2d3 and 772f477.

📒 Files selected for processing (3)
  • flashinfer/__init__.py
  • flashinfer/gdn_decode.py
  • tests/gdn/test_gdn_decode_unified.py

@Dayuxiaoshui
Author

@kaixih Quick note on naming: the RFC suggests the unified entry point be called gated_delta_rule_decode, but that name is already used by the existing KV-layout API. To avoid breaking callers, this PR introduces the unified API as gated_delta_rule_decode_unified and keeps the old names as deprecation shims. If you prefer to switch to gated_delta_rule_decode for the unified API (and e.g. rename the current KV-layout one or deprecate it under a different name), I can follow up with a small rename patch. Please advise.

@kaixih
Collaborator

kaixih commented Mar 7, 2026

On naming: we'd recommend the following approach rather than keeping _unified in the name (which is an implementation detail callers shouldn't see):

  1. Rename the existing gated_delta_rule_decode (KV layout) → gated_delta_rule_decode_kv: this is descriptive, clearly signals it's the legacy KV-layout-specific path, and the function is not currently exported from flashinfer/__init__.py so the blast radius is small.
  2. Use gated_delta_rule_decode for the unified API: the clean name with no suffix, as the RFC originally proposed.
  3. Keep the old gated_delta_rule_decode_pretranspose and gated_delta_rule_mtp as deprecation shims (as you've already done), and add gated_delta_rule_decode_kv as a shim for the renamed KV function.

@kahyunnam @yzh119 what do you think?

@kaixih kaixih mentioned this pull request Mar 7, 2026
40 tasks
@Dayuxiaoshui
Author

@kaixih I agree with this viewpoint, and I will make corrections and resubmit.

@Dayuxiaoshui
Author

@kaixih Thanks, done. We renamed the KV path to gated_delta_rule_decode_kv, use gated_delta_rule_decode for the unified API, and kept the three shims (pretranspose, kv, mtp) as you suggested.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (5)
flashinfer/gdn_decode.py (4)

2674-2691: ⚠️ Potential issue | 🟠 Major

Add the standard backend requirement guard to the public decode API.

This is a public SM90+ entrypoint, but unsupported devices still fall through to JIT compilation and fail late instead of getting the usual capability check up front. The same guard should stay consistent across the public shims below. As per coding guidelines, "Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2674 - 2691, The public decoder
function gated_delta_rule_decode lacks the standard backend capability guard;
add the `@backend_requirement` decorator above the gated_delta_rule_decode
definition and ensure it calls the module's is_compute_capability_supported(cc)
and is_backend_supported() helpers so unsupported devices fail early; also add
the same `@backend_requirement` to the other public SM90+ shim functions in this
file to keep guards consistent across the public decode APIs.
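
To illustrate the fail-fast behavior this comment asks for, here is a generic capability-guard decorator. It is only a sketch: `require_compute_capability` and its `get_cc` argument are hypothetical names invented for this example, and FlashInfer's real `@backend_requirement` decorator has its own signature, which is not reproduced here.

```python
import functools


def require_compute_capability(min_cc, get_cc):
    """Hypothetical decorator: reject calls when the device is too old.

    min_cc is a (major, minor) tuple; get_cc is a callable returning the
    current device's (major, minor). Both are stand-ins for this sketch.
    """

    def decorator(fn):
        def is_compute_capability_supported(cc):
            return tuple(cc) >= tuple(min_cc)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            cc = get_cc()
            if not is_compute_capability_supported(cc):
                raise RuntimeError(
                    f"{fn.__name__} requires compute capability >= {min_cc}, got {cc}"
                )
            return fn(*args, **kwargs)

        # Expose the check so callers can query support without calling the API.
        wrapper.is_compute_capability_supported = is_compute_capability_supported
        return wrapper

    return decorator


# Simulate an SM80 device calling an SM90+ API.
@require_compute_capability((9, 0), get_cc=lambda: (8, 0))
def decode_stub():
    return "ran"
```

With a guard like this the unsupported call fails at the Python boundary with a readable message, instead of failing later during JIT compilation.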

2926-2933: ⚠️ Potential issue | 🟠 Major

Reject ambiguous state + initial_state calls in the shim.

The restored validation still allows both to be passed together; the wrapper just takes the pool path and silently ignores state. Legacy callers used to get a clear error for that ambiguous combination.

Suggested guard
     use_pool = initial_state is not None
     if use_pool != (initial_state_indices is not None):
         raise ValueError(
             "initial_state and initial_state_indices must be provided together"
         )
+    if state is not None and initial_state is not None:
+        raise ValueError("state and initial_state are mutually exclusive")
     if state is None and initial_state is None:
         raise ValueError("Either state or initial_state must be provided")
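
For reference, the guard above can be exercised in isolation. The helper below is a sketch that lifts the three checks into a standalone function; the name `validate_state_args` is hypothetical, and plain Python objects stand in for tensors.

```python
def validate_state_args(state, initial_state, initial_state_indices):
    """Return True if the pool path should be used; raise on bad combinations."""
    use_pool = initial_state is not None
    if use_pool != (initial_state_indices is not None):
        raise ValueError(
            "initial_state and initial_state_indices must be provided together"
        )
    if state is not None and initial_state is not None:
        raise ValueError("state and initial_state are mutually exclusive")
    if state is None and initial_state is None:
        raise ValueError("Either state or initial_state must be provided")
    return use_pool


# Direct-state call and pooled call are both accepted.
direct = validate_state_args("state", None, None)
pooled = validate_state_args(None, "pool", [0, 1])

# Passing both state and a pool is the ambiguous case the comment wants rejected.
try:
    validate_state_args("state", "pool", [0, 1])
    ambiguous_rejected = False
except ValueError:
    ambiguous_rejected = True
```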
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2926 - 2933, The current shim silently
prefers the pooling path when both state and initial_state are passed; change
the validation in the block around use_pool to explicitly reject the ambiguous
combination by raising a ValueError if both state is not None and initial_state
is not None (or equivalently if use_pool and state is not None). Update the
checks involving use_pool, initial_state_indices, initial_state and state (the
variables/functions referenced: use_pool, initial_state_indices, initial_state,
state) so that: 1) passing initial_state requires initial_state_indices, 2)
passing both state and initial_state raises ValueError, and 3) the existing
"Either state or initial_state must be provided" behavior remains intact.

2785-2838: ⚠️ Potential issue | 🟠 Major

BF16 T>1 direct-state calls still ignore rollback knobs.

When state.dtype == torch.bfloat16, T>1, and state_indices is None, this branch still forwards to _gated_delta_rule_decode_pretranspose_impl() without checking disable_state_update or intermediate_states_buffer. Both arguments are silently ignored, so callers can request read-only execution or rollback caching and still get in-place mutation with no cached states.

Suggested guard
     if state.dtype == torch.bfloat16:
+        if T > 1 and (
+            disable_state_update or intermediate_states_buffer is not None
+        ):
+            raise NotImplementedError(
+                "VK bf16 T>1 does not support disable_state_update or "
+                "intermediate_states_buffer yet"
+            )
         if T not in (1, 2, 3, 4) or K != 128 or V != 128:
             raise ValueError(
                 f"VK bf16 path requires T in {{1,2,3,4}} and K=V=128, got T={T}, K={K}, V={V}"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2785 - 2838, The BF16 non-pool branch
currently ignores rollback/read-only knobs: when state.dtype == torch.bfloat16
and use_pool is False (the branch that validates state.shape and calls
_gated_delta_rule_decode_pretranspose_impl with state=state), add the same guard
as the pool path to check if disable_state_update is True or
intermediate_states_buffer is not None and raise NotImplementedError (with the
same or similar message instructing to use fp32 state for MTP); ensure this
check occurs before validating state.shape and before calling
_gated_delta_rule_decode_pretranspose_impl so callers cannot request
read-only/rollback behavior that will be silently ignored.

2662-2671: ⚠️ Potential issue | 🟠 Major

Negative state_indices are rejected instead of treated as padding.

RFC 5.7/5.8 calls out negative-index padding semantics, but this helper raises on every negative entry. The unified API therefore still can't represent padded rows, and the new tests now lock in the opposite contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 2662 - 2671, The helper
_check_state_indices_bounds currently treats negative state_indices as errors
but per RFC 5.7/5.8 negatives are padding and should be ignored; change the
validation to only validate non-negative indices: if state_indices.numel() == 0
return; build a non_negative mask = state_indices >= 0 and if no non-negative
entries return; compute bad = non_negative & (state_indices >= pool_size) and if
bad.any() raise ValueError with the first out-of-range value from
state_indices[bad]; keep references to the same symbols
(_check_state_indices_bounds, state_indices, pool_size, bad) so the change is
local and preserves behavior for real indices while allowing negative padding.
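
The padding-aware check described here could look roughly like the sketch below. It operates on a plain list of ints so it runs without torch; the real helper would use tensor masks as the prompt describes, and the function name (without the leading underscore) is a stand-in.

```python
def check_state_indices_bounds(state_indices, pool_size):
    """Validate indices, treating negative entries as padding.

    Only non-negative entries are checked against pool_size; negative
    entries are skipped entirely.
    """
    real = [i for i in state_indices if i >= 0]
    bad = [i for i in real if i >= pool_size]
    if bad:
        raise ValueError(
            f"state_indices value {bad[0]} is out of range for pool_size={pool_size}"
        )


# Padded rows (-1) pass; only real indices are range-checked.
check_state_indices_bounds([0, -1, 3], pool_size=4)
check_state_indices_bounds([], pool_size=4)

try:
    check_state_indices_bounds([0, 4], pool_size=4)
    out_of_range_rejected = False
except ValueError:
    out_of_range_rejected = True
```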
tests/gdn/test_gdn_decode.py (1)

216-280: ⚠️ Potential issue | 🟠 Major

Add one BF16 T>1 pool parity case.

All MTP parity tests here still build torch.float32 state pools, so gated_delta_rule_decode() never exercises its state.dtype == torch.bfloat16 VK branch for T>1. That leaves a distinct dispatcher/backend path unverified even though this PR adds BF16 routing there.

Also applies to: 283-360

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_gdn_decode.py` around lines 216 - 280, The test file builds
state_pool as torch.float32 so gated_delta_rule_decode never exercises its
bfloat16 VK branch for T>1; add an additional parameterized parity case where
state_pool (and pool_legacy/pool_unified) are created with dtype=torch.bfloat16
for T>1 so gated_delta_rule_decode and _gated_delta_rule_mtp_impl are both run
with bfloat16 state pools; update the test (functions
test_gated_delta_rule_decode_vk_fp32_mtp_matches_mtp and the similar block at
lines 283-360) to include a BF16 variant (or a separate test) that constructs
state_pool with dtype=torch.bfloat16 and asserts out_unified == out_legacy and
pool_unified == pool_legacy with the same tolerances.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2965-2999: The gated_delta_rule_decode_kv function currently emits
a DeprecationWarning and a docstring marking it deprecated; remove the warning
and change the docstring so gated_delta_rule_decode_kv remains a stable,
supported KV-specific entrypoint while still delegating to
gated_delta_rule_decode (keep the call that passes state_layout="KV" and all
parameters intact). Locate the gated_delta_rule_decode_kv function and delete
the warnings.warn block and the “Deprecated” wording in the docstring so callers
migrating to this explicit KV API do not get deprecated warnings.


ℹ️ Review info

📥 Commits

Reviewing files that changed from the base of the PR and between 772f477 and 1a37574.

📒 Files selected for processing (5)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/__init__.py
  • flashinfer/gdn_decode.py
  • tests/gdn/test_decode_delta_rule.py
  • tests/gdn/test_gdn_decode.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/__init__.py

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (2)
flashinfer/gdn_decode.py (2)

649-666: 🛠️ Refactor suggestion | 🟠 Major

Add @backend_requirement decorator for SM90+ requirement.

The unified API is documented as requiring SM90+ (line 719), but lacks the @backend_requirement decorator. Without it, callers on unsupported GPUs will fail late during JIT compilation rather than getting a clear capability check upfront.

As per coding guidelines: "Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 649 - 666, The gated_delta_rule_decode
API lacks the required SM90+ capability check decorator; add
`@backend_requirement` with the SM90+ requirement above the
gated_delta_rule_decode definition so callers get an upfront capability check.
Ensure the decorator uses the existing helper functions
is_compute_capability_supported(cc) and is_backend_supported() to express the
SM90+ requirement (matching the unified API docs) so the GPU capability/backend
support is validated before JIT compilation.
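
To illustrate what an upfront capability check buys over a late JIT failure, here is a minimal sketch of a capability-gating decorator. It is not FlashInfer's `@backend_requirement` and does not reproduce its signature; `requires_min_cc`, `get_cc`, and the two stub functions are hypothetical names, and the compute capability is supplied by a callable so the sketch runs without a GPU.

```python
import functools
from typing import Callable, Tuple


def requires_min_cc(min_cc: Tuple[int, int], get_cc: Callable[[], Tuple[int, int]]):
    """Hypothetical capability gate (not FlashInfer's decorator).

    min_cc is a (major, minor) pair, e.g. (9, 0) for SM90. get_cc returns the
    current device's (major, minor); with PyTorch this could be
    torch.cuda.get_device_capability.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            cc = get_cc()
            # Tuples compare lexicographically, so (8, 9) < (9, 0).
            if cc < min_cc:
                raise RuntimeError(
                    f"{fn.__name__} requires compute capability >= "
                    f"{min_cc[0]}.{min_cc[1]}, got {cc[0]}.{cc[1]}"
                )
            return fn(*args, **kwargs)

        return wrapper

    return decorator


# Stubs standing in for a decode entry point on two simulated devices.
@requires_min_cc((9, 0), get_cc=lambda: (9, 0))
def decode_on_sm90(x):
    return x


@requires_min_cc((9, 0), get_cc=lambda: (8, 0))
def decode_on_sm80(x):
    return x
```

With a gate like this, the call on the simulated SM80 device fails immediately with a message naming the required capability, before any kernel compilation is attempted.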

940-974: ⚠️ Potential issue | 🟠 Major

Remove deprecation warning from gated_delta_rule_decode_kv.

Per the PR discussion, reviewer kaixih recommended keeping gated_delta_rule_decode_kv as the stable explicit KV-layout entrypoint. This shim was introduced as the new name for the legacy KV path, so deprecating it immediately contradicts the migration guidance. Keep it as a thin delegation wrapper without the warning.

🔧 Suggested fix
 @flashinfer_api
 def gated_delta_rule_decode_kv(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     state: torch.Tensor,
     A_log: torch.Tensor,
     a: torch.Tensor,
     dt_bias: torch.Tensor,
     b: torch.Tensor,
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Deprecated: use gated_delta_rule_decode(..., state_layout=\"KV\") instead."""
-    warnings.warn(
-        "gated_delta_rule_decode_kv is deprecated and will be removed in a future "
-        "version. Use gated_delta_rule_decode(..., state_layout='KV') instead.",
-        DeprecationWarning,
-        stacklevel=2,
-    )
+    """KV-layout decode API. Delegates to gated_delta_rule_decode with state_layout='KV'."""
     return gated_delta_rule_decode(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 940 - 974, The
gated_delta_rule_decode_kv wrapper currently emits a DeprecationWarning and its
docstring marks it deprecated; remove the warnings.warn call and the
DeprecationWarning/stacklevel arguments and update the docstring to reflect that
this is the stable KV-layout entrypoint, leaving the function body to simply
delegate to gated_delta_rule_decode(..., state_layout="KV") with the same
parameters (q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm).
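
The difference between a deprecated shim and a stable thin wrapper can be sketched with the standard `warnings` module alone. All names below (`decode_unified`, `decode_kv`, `decode_legacy`, `count_deprecations`) are simplified stand-ins for illustration, not the functions in `flashinfer/gdn_decode.py`, and the `"VK"` default is an assumption.

```python
import warnings


def decode_unified(x, state_layout="VK"):
    """Stand-in for the unified entry point; echoes its arguments."""
    return (x, state_layout)


def decode_kv(x):
    """Stable explicit KV-layout entry point: delegates without warning."""
    return decode_unified(x, state_layout="KV")


def decode_legacy(x):
    """Deprecated shim: warns once per call, then delegates."""
    warnings.warn(
        "decode_legacy is deprecated; use decode_unified(...) instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not the shim
    )
    return decode_unified(x, state_layout="VK")


def count_deprecations(fn, *args):
    """Call fn and return how many DeprecationWarnings it emitted."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        fn(*args)
    return sum(issubclass(w.category, DeprecationWarning) for w in caught)
```

A helper like `count_deprecations` is also a convenient shape for a regression test: the stable wrapper should report zero warnings and the legacy shim exactly one.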
🧹 Nitpick comments (2)
flashinfer/gdn_decode.py (2)

30-31: Unused functools import.

functools is imported but not used anywhere in the file. Remove it if not needed.

🧹 Suggested fix
-import functools
 import warnings
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 30 - 31, Remove the unused import of
functools from the top of gdn_decode.py: locate the import statement "import
functools" and delete it (ensure no references to functools exist elsewhere in
functions or classes in this file, e.g., any usage in functions or decorators);
after removal, run the linter/tests to confirm no unresolved references remain.

722-723: Unused variable H.

H is unpacked from q.shape but never used in the function body. Consider using _ to indicate it's intentionally ignored.

🧹 Suggested fix
-    B, T, H, K = q.shape
+    B, T, _, K = q.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 722 - 723, The tuple unpacking of
q.shape assigns H but H is never used; change the unpack to ignore that
dimension (e.g., B, T, _, K = q.shape) so the unused variable is explicit, and
ensure no other code relies on H (update any references if present); leave the v
unpack (_, _, HV, V = v.shape) as-is.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2de30a05-590d-4b0c-b214-d6829960d2bf

📥 Commits

Reviewing files that changed from the base of the PR and between 1a37574 and 1a75db1.

📒 Files selected for processing (3)
  • benchmarks/bench_gdn_decode.py
  • flashinfer/gdn_decode.py
  • tests/gdn/test_decode_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • benchmarks/bench_gdn_decode.py
  • tests/gdn/test_decode_delta_rule.py

@Dayuxiaoshui
Author

cc @kaixih

first_bad = state_indices[bad].flatten()[0].item()
raise ValueError(
    f"state_indices must be in [0, pool_size={pool_size}); got out-of-range value {first_bad}"
)
Collaborator


Do we need this bounds check? Two concerns: (1) the range validation adds overhead that may not be acceptable on the hot path — we generally rely on the framework/caller to ensure indices are valid. (2) More importantly, negative values are intentional: they represent padding (dummy) entries during CUDA graph capture and should be supported, not rejected.
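
If the check is kept at all, a padding-tolerant version could look like the sketch below. It uses plain Python lists in place of tensors so it runs anywhere; the function name `check_state_indices_bounds` and the error wording are illustrative assumptions. In the tensor version, a data-dependent branch such as `if bad.any():` on a CUDA tensor forces a device-to-host synchronization, which is one concrete source of the hot-path overhead raised in concern (1).

```python
from typing import Sequence


def check_state_indices_bounds(state_indices: Sequence[int], pool_size: int) -> None:
    """Reject indices >= pool_size; treat negative entries as padding."""
    # Negative values are padding (dummy) entries and are skipped entirely.
    non_negative = [i for i in state_indices if i >= 0]
    if not non_negative:
        return
    # Only real (non-negative) indices are range-checked against the pool.
    bad = [i for i in non_negative if i >= pool_size]
    if bad:
        raise ValueError(
            f"state_indices must be negative (padding) or in "
            f"[0, pool_size={pool_size}); got out-of-range value {bad[0]}"
        )


# Padding entries pass; only a real index past the pool end is rejected.
check_state_indices_bounds([-1, 0, 3], pool_size=4)
check_state_indices_bounds([-1, -1], pool_size=4)
```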



@flashinfer_api
def gated_delta_rule_decode(
Collaborator


Consider adding @backend_requirement for the SM90+ guard.

)
if T == 1:
    if use_pool:
        raise NotImplementedError(
Collaborator


Can we double-check whether this is supported? From the existing code, I can see:

 run_pretranspose_decode(
        h0_source,
        A_log,
        a,
        dt_bias,
        q,
        k,
        v,
        b,
        output,
        B,
        T,
        H,
        HV,
        K,
        V,
        scale,
        use_qk_l2norm,
        use_pool_indexing=use_pool_indexing,
        initial_state_indices=initial_state_indices,
    )


Development

Successfully merging this pull request may close these issues.

[RFC] Unified GDN Decode/Prefill API

2 participants