feat(gdn): add unified decode API and deprecation shims (RFC 5.7, 5.8) #2706
Dayuxiaoshui wants to merge 3 commits into flashinfer-ai:main
Conversation
- Add gated_delta_rule_decode_unified as a single entry point with state_layout (VK/KV), state_indices, intermediate_states_buffer, and disable_state_update.
- Rename existing implementations to _*_impl; the unified API dispatches by state_layout, state dtype, and T.
- Add deprecation shims for gated_delta_rule_decode_pretranspose, gated_delta_rule_decode (KV), and gated_delta_rule_mtp with DeprecationWarning.
- Export gated_delta_rule_decode_unified in __init__.py.
- Add tests/gdn/test_gdn_decode_unified.py: cross-check vs _*_impl, intermediate_states_buffer, edge cases (pool_size=1, B=1), error paths.
Summary of Changes

This pull request streamlines the Gated Delta Rule (GDN) decode functionality by introducing a unified API. The change centralizes dispatch logic for the different GDN decode backends, improving maintainability and future extensibility. Concurrently, it transitions users away from the older, specialized decode functions via deprecation shims, ensuring a smooth migration path while consolidating the API surface.
Activity
📝 Walkthrough

Adds a unified decode API.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant UnifiedDecode as gated_delta_rule_decode
    participant ValidateState as _check_state_indices_bounds
    participant VKBackend as _gated_delta_rule_decode_pretranspose_impl
    participant KVBackend as _gated_delta_rule_decode_kv_impl
    participant MTPBackend as _gated_delta_rule_mtp_impl
    Caller->>UnifiedDecode: call(state_layout, state, state_indices, tokens, ...)
    UnifiedDecode->>ValidateState: validate state_indices (if provided)
    ValidateState-->>UnifiedDecode: ok / raise ValueError
    alt state_layout == "KV" OR tokens.T == 1 (KV path)
        UnifiedDecode->>KVBackend: dispatch to KV impl
        KVBackend-->>UnifiedDecode: return output
    else state_layout == "VK" AND bf16 AND pool AND tokens.T > 1 (MTP)
        UnifiedDecode->>MTPBackend: dispatch to MTP impl
        MTPBackend-->>UnifiedDecode: return output
    else state_layout == "VK"
        UnifiedDecode->>VKBackend: dispatch to VK/pretranspose impl
        VKBackend-->>UnifiedDecode: return output
    end
    UnifiedDecode-->>Caller: final decoded result or error
```
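The routing in the diagram above can be condensed into a small pure-Python sketch. The function name and the returned backend labels are illustrative assumptions, not the library's actual code; the real dispatcher operates on torch tensors rather than plain values:

```python
def pick_backend(state_layout: str, dtype: str, T: int, has_pool: bool) -> str:
    """Illustrative dispatch mirroring the sequence diagram above."""
    if state_layout not in ("VK", "KV"):
        raise ValueError(f"unknown state_layout: {state_layout}")
    if state_layout == "KV" or T == 1:
        return "kv_impl"  # single-token or explicit KV-layout path
    if dtype == "bfloat16" and has_pool and T > 1:
        return "mtp_impl"  # multi-token prediction path (pooled bf16 state)
    return "pretranspose_impl"  # default VK path
```

Under this sketch, a VK-layout bf16 call with a state pool and T=2 would route to the MTP backend, while the same call with T=1 would take the KV path.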
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a new unified API gated_delta_rule_decode_unified for Gated Delta Rule (GDN) decode operations, simplifying the API surface by renaming old backend implementations, adding deprecation shims, and including comprehensive tests. However, a security audit identified critical issues related to insecure input validation using assert statements and missing bounds checks for user-supplied indices, which could lead to out-of-bounds memory access on the GPU. These assert statements should be replaced with explicit checks for better robustness.
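The review's point about `assert`-based validation can be illustrated with a minimal sketch (a hypothetical helper, not the PR's actual code): `assert` statements are stripped when Python runs with `-O`, so user-facing input validation must raise explicitly:

```python
def validate_state_indices(state_indices, pool_size: int) -> None:
    """Reject out-of-range indices with an explicit error instead of assert.

    An `assert` here would vanish under `python -O`, silently letting
    out-of-bounds indices through to the GPU kernel.
    """
    for idx in state_indices:
        if idx < 0 or idx >= pool_size:
            raise ValueError(
                f"state_indices must be in [0, {pool_size}); got {idx}"
            )
```

The explicit `ValueError` survives optimized runs and gives callers an actionable message instead of an out-of-bounds access on device.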
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2662-2679: The unified public entrypoint
gated_delta_rule_decode_unified is missing the required backend capability
guard; wrap it with the `@backend_requirement` decorator and ensure the code calls
the helper checks (is_compute_capability_supported(cc) and
is_backend_supported()) for the SM90+ requirement so callers fail fast on
unsupported GPUs; apply the same change to the other related unified API(s)
around the other decode function(s) referenced (the functions at the diff around
the 2730–2732 region) so both use the backend guard and support helpers
consistently.
- Around line 2704-2706: state_indices can be negative (padding) but the MTP
path/kernel currently skips padded rows leaving preallocated output rows stale;
update the call site that forwards state_indices into the MTP path to either (A)
validate and reject negative indices by raising an error if any(state_indices <
0) when an output buffer is supplied, or (B) proactively zero the corresponding
output rows before launching the MTP/kernel when output is provided;
specifically, check state_indices for negatives, and if choosing (B) zero
output[mask] (where mask = state_indices < 0) prior to the mtp kernel launch so
padded rows do not retain old values. Ensure this change is applied both where
state_indices is forwarded into the MTP path and in the analogous block handling
lines around the second occurrence mentioned (the other MTP call).
- Around line 2895-2930: The shim gated_delta_rule_decode_pretranspose lost
legacy validation: restore checks so callers must provide either state (state)
or the per-step initial state pair (initial_state and initial_state_indices) but
not both, and if initial_state is provided then initial_state_indices must also
be provided (and conversely initial_state_indices must be None when
initial_state is None); add explicit ValueError(s) with clear messages before
delegating to gated_delta_rule_decode_unified to prevent bad calls (refer to
symbols initial_state, initial_state_indices, state,
gated_delta_rule_decode_pretranspose, and gated_delta_rule_decode_unified).
- Around line 2771-2815: The BF16 branch currently always forwards the state
into _gated_delta_rule_decode_pretranspose_impl (via
initial_state/initial_state_indices or state), which prevents honoring
disable_state_update and prevents filling intermediate_states_buffer; change the
BF16 branch so that if disable_state_update is True you do not pass the current
state (pass state=None and omit initial_state/initial_state_indices) and if
intermediate_states_buffer is requested do not use the unified BF16 path but
fall back to the non-bf16/FP32 decode path that supports population of the
rollback buffer (or otherwise implement buffer population), using the same
checks around state.shape/state_indices but routing to the alternative code path
instead of always calling _gated_delta_rule_decode_pretranspose_impl.
In `@tests/gdn/test_gdn_decode_unified.py`:
- Around line 215-357: The tests miss exercising the bfloat16 dispatcher path
and fail to assert intermediate buffer rollback; update the param setup so at
least one T>1 case uses state_pool and intermediate buffers with
dtype=torch.bfloat16 (so gated_delta_rule_decode_unified hits the state.dtype ==
torch.bfloat16 branch) and add an assertion comparing intermed_unified to
intermed_legacy (e.g., torch.testing.assert_close(intermed_unified,
intermed_legacy, atol=5e-3, rtol=5e-3)) in
test_unified_vk_fp32_mtp_with_intermediate_buffer_matches_mtp; ensure the same
dtype change is applied consistently for pool_unified/pool_legacy and for any
places that construct intermed_buf so the BF16 dispatcher path in
gated_delta_rule_decode_unified and rollback caching behavior are both
exercised.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ded46dee-cac3-4404-85bf-c183c9d98073
📒 Files selected for processing (3)
- flashinfer/__init__.py
- flashinfer/gdn_decode.py
- tests/gdn/test_gdn_decode_unified.py
@kaixih Quick note on naming: the RFC suggests the unified entry point be called gated_delta_rule_decode, but that name is already used by the existing KV-layout API. To avoid breaking callers, this PR introduces the unified API as gated_delta_rule_decode_unified and keeps the old names as deprecation shims. If you prefer to switch to gated_delta_rule_decode for the unified API (and e.g. rename the current KV-layout one or deprecate it under a different name), I can follow up with a small rename patch. Please advise.
On naming: we'd recommend the following approach rather than keeping _unified in the name (which is an implementation detail callers shouldn't see):
@kahyunnam @yzh119 what do you think?
@kaixih I agree with this viewpoint, and I will make corrections and resubmit.
…n shims, and safety checks (RFC 5.7/5.8)
@kaixih Thanks, done. We renamed the KV path to gated_delta_rule_decode_kv, made gated_delta_rule_decode the unified API, and kept the three shims (pretranspose, kv, mtp) as you suggested.
Actionable comments posted: 1
♻️ Duplicate comments (5)
flashinfer/gdn_decode.py (4)
Lines 2674-2691: ⚠️ Potential issue | 🟠 Major

Add the standard backend requirement guard to the public decode API.

This is a public SM90+ entrypoint, but unsupported devices still fall through to JIT compilation and fail late instead of getting the usual capability check up front. The same guard should stay consistent across the public shims below. As per coding guidelines: "Use the `@backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 2674 - 2691, The public decoder function gated_delta_rule_decode lacks the standard backend capability guard; add the `@backend_requirement` decorator above the gated_delta_rule_decode definition and ensure it calls the module's is_compute_capability_supported(cc) and is_backend_supported() helpers so unsupported devices fail early; also add the same `@backend_requirement` to the other public SM90+ shim functions in this file to keep guards consistent across the public decode APIs.
Lines 2926-2933: ⚠️ Potential issue | 🟠 Major

Reject ambiguous `state` + `initial_state` calls in the shim.

The restored validation still allows both to be passed together; the wrapper just takes the pool path and silently ignores `state`. Legacy callers used to get a clear error for that ambiguous combination.

Suggested guard

```diff
 use_pool = initial_state is not None
 if use_pool != (initial_state_indices is not None):
     raise ValueError(
         "initial_state and initial_state_indices must be provided together"
     )
+if state is not None and initial_state is not None:
+    raise ValueError("state and initial_state are mutually exclusive")
 if state is None and initial_state is None:
     raise ValueError("Either state or initial_state must be provided")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 2926 - 2933, The current shim silently prefers the pooling path when both state and initial_state are passed; change the validation in the block around use_pool to explicitly reject the ambiguous combination by raising a ValueError if both state is not None and initial_state is not None (or equivalently if use_pool and state is not None). Update the checks involving use_pool, initial_state_indices, initial_state and state (the variables/functions referenced: use_pool, initial_state_indices, initial_state, state) so that: 1) passing initial_state requires initial_state_indices, 2) passing both state and initial_state raises ValueError, and 3) the existing "Either state or initial_state must be provided" behavior remains intact.
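The contract described in this comment can be sketched as a standalone validator. This is a simplified illustration under the assumption that the shim's three arguments are interchangeable with any non-None objects; the real shim operates on torch tensors:

```python
def validate_shim_args(state, initial_state, initial_state_indices) -> None:
    """Enforce the state / initial_state contract described above."""
    use_pool = initial_state is not None
    # 1) The pool pair must be provided together.
    if use_pool != (initial_state_indices is not None):
        raise ValueError(
            "initial_state and initial_state_indices must be provided together"
        )
    # 2) Passing both forms is ambiguous and must be rejected.
    if state is not None and initial_state is not None:
        raise ValueError("state and initial_state are mutually exclusive")
    # 3) At least one form is required.
    if state is None and initial_state is None:
        raise ValueError("Either state or initial_state must be provided")
```

With this guard in place, the shim fails fast on the ambiguous combination instead of silently preferring the pool path.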
Lines 2785-2838: ⚠️ Potential issue | 🟠 Major

BF16 `T>1` direct-state calls still ignore rollback knobs.

When `state.dtype == torch.bfloat16`, `T>1`, and `state_indices is None`, this branch still forwards to `_gated_delta_rule_decode_pretranspose_impl()` without checking `disable_state_update` or `intermediate_states_buffer`. Both arguments are silently ignored, so callers can request read-only execution or rollback caching and still get in-place mutation with no cached states.

Suggested guard

```diff
 if state.dtype == torch.bfloat16:
+    if T > 1 and (
+        disable_state_update or intermediate_states_buffer is not None
+    ):
+        raise NotImplementedError(
+            "VK bf16 T>1 does not support disable_state_update or "
+            "intermediate_states_buffer yet"
+        )
     if T not in (1, 2, 3, 4) or K != 128 or V != 128:
         raise ValueError(
             f"VK bf16 path requires T in {{1,2,3,4}} and K=V=128, got T={T}, K={K}, V={V}"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 2785 - 2838, The BF16 non-pool branch currently ignores rollback/read-only knobs: when state.dtype == torch.bfloat16 and use_pool is False (the branch that validates state.shape and calls _gated_delta_rule_decode_pretranspose_impl with state=state), add the same guard as the pool path to check if disable_state_update is True or intermediate_states_buffer is not None and raise NotImplementedError (with the same or similar message instructing to use fp32 state for MTP); ensure this check occurs before validating state.shape and before calling _gated_delta_rule_decode_pretranspose_impl so callers cannot request read-only/rollback behavior that will be silently ignored.
Lines 2662-2671: ⚠️ Potential issue | 🟠 Major

Negative `state_indices` are rejected instead of treated as padding.

RFC 5.7/5.8 calls out negative-index padding semantics, but this helper raises on every negative entry. The unified API therefore still can't represent padded rows, and the new tests now lock in the opposite contract.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 2662 - 2671, The helper _check_state_indices_bounds currently treats negative state_indices as errors but per RFC 5.7/5.8 negatives are padding and should be ignored; change the validation to only validate non-negative indices: if state_indices.numel() == 0 return; build a non_negative mask = state_indices >= 0 and if no non-negative entries return; compute bad = non_negative & (state_indices >= pool_size) and if bad.any() raise ValueError with the first out-of-range value from state_indices[bad]; keep references to the same symbols (_check_state_indices_bounds, state_indices, pool_size, bad) so the change is local and preserves behavior for real indices while allowing negative padding.

tests/gdn/test_gdn_decode.py (1)
Lines 216-280: ⚠️ Potential issue | 🟠 Major

Add one BF16 `T>1` pool parity case.

All MTP parity tests here still build `torch.float32` state pools, so `gated_delta_rule_decode()` never exercises its `state.dtype == torch.bfloat16` VK branch for `T>1`. That leaves a distinct dispatcher/backend path unverified even though this PR adds BF16 routing there.

Also applies to: 283-360
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_gdn_decode.py` around lines 216 - 280, The test file builds state_pool as torch.float32 so gated_delta_rule_decode never exercises its bfloat16 VK branch for T>1; add an additional parameterized parity case where state_pool (and pool_legacy/pool_unified) are created with dtype=torch.bfloat16 for T>1 so gated_delta_rule_decode and _gated_delta_rule_mtp_impl are both run with bfloat16 state pools; update the test (functions test_gated_delta_rule_decode_vk_fp32_mtp_matches_mtp and the similar block at lines 283-360) to include a BF16 variant (or a separate test) that constructs state_pool with dtype=torch.bfloat16 and asserts out_unified == out_legacy and pool_unified == pool_legacy with the same tolerances.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2965-2999: The gated_delta_rule_decode_kv function currently emits
a DeprecationWarning and a docstring marking it deprecated; remove the warning
and change the docstring so gated_delta_rule_decode_kv remains a stable,
supported KV-specific entrypoint while still delegating to
gated_delta_rule_decode (keep the call that passes state_layout="KV" and all
parameters intact). Locate the gated_delta_rule_decode_kv function and delete
the warnings.warn block and the “Deprecated” wording in the docstring so callers
migrating to this explicit KV API do not get deprecated warnings.
---
Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 2674-2691: The public decoder function gated_delta_rule_decode
lacks the standard backend capability guard; add the `@backend_requirement`
decorator above the gated_delta_rule_decode definition and ensure it calls the
module's is_compute_capability_supported(cc) and is_backend_supported() helpers
so unsupported devices fail early; also add the same `@backend_requirement` to the
other public SM90+ shim functions in this file to keep guards consistent across
the public decode APIs.
- Around line 2926-2933: The current shim silently prefers the pooling path when
both state and initial_state are passed; change the validation in the block
around use_pool to explicitly reject the ambiguous combination by raising a
ValueError if both state is not None and initial_state is not None (or
equivalently if use_pool and state is not None). Update the checks involving
use_pool, initial_state_indices, initial_state and state (the
variables/functions referenced: use_pool, initial_state_indices, initial_state,
state) so that: 1) passing initial_state requires initial_state_indices, 2)
passing both state and initial_state raises ValueError, and 3) the existing
"Either state or initial_state must be provided" behavior remains intact.
- Around line 2785-2838: The BF16 non-pool branch currently ignores
rollback/read-only knobs: when state.dtype == torch.bfloat16 and use_pool is
False (the branch that validates state.shape and calls
_gated_delta_rule_decode_pretranspose_impl with state=state), add the same guard
as the pool path to check if disable_state_update is True or
intermediate_states_buffer is not None and raise NotImplementedError (with the
same or similar message instructing to use fp32 state for MTP); ensure this
check occurs before validating state.shape and before calling
_gated_delta_rule_decode_pretranspose_impl so callers cannot request
read-only/rollback behavior that will be silently ignored.
- Around line 2662-2671: The helper _check_state_indices_bounds currently treats
negative state_indices as errors but per RFC 5.7/5.8 negatives are padding and
should be ignored; change the validation to only validate non-negative indices:
if state_indices.numel() == 0 return; build a non_negative mask = state_indices
>= 0 and if no non-negative entries return; compute bad = non_negative &
(state_indices >= pool_size) and if bad.any() raise ValueError with the first
out-of-range value from state_indices[bad]; keep references to the same symbols
(_check_state_indices_bounds, state_indices, pool_size, bad) so the change is
local and preserves behavior for real indices while allowing negative padding.
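The padding-aware validation described in this prompt can be sketched without torch, using plain lists in place of tensors (an illustrative assumption; the real `_check_state_indices_bounds` works on tensor masks):

```python
def check_state_indices_bounds(state_indices, pool_size: int) -> None:
    """Validate only non-negative indices; negatives are padding and are skipped.

    Mirrors the suggested fix: empty input and all-padding input pass,
    and only a non-negative index >= pool_size raises.
    """
    if len(state_indices) == 0:
        return
    bad = [i for i in state_indices if i >= 0 and i >= pool_size]
    if bad:
        raise ValueError(
            f"state_indices must be in [0, pool_size={pool_size}); "
            f"got out-of-range value {bad[0]}"
        )
```

This keeps the fail-fast behavior for genuinely invalid indices while letting CUDA-graph capture pass -1 for dummy rows.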
In `@tests/gdn/test_gdn_decode.py`:
- Around line 216-280: The test file builds state_pool as torch.float32 so
gated_delta_rule_decode never exercises its bfloat16 VK branch for T>1; add an
additional parameterized parity case where state_pool (and
pool_legacy/pool_unified) are created with dtype=torch.bfloat16 for T>1 so
gated_delta_rule_decode and _gated_delta_rule_mtp_impl are both run with
bfloat16 state pools; update the test (functions
test_gated_delta_rule_decode_vk_fp32_mtp_matches_mtp and the similar block at
lines 283-360) to include a BF16 variant (or a separate test) that constructs
state_pool with dtype=torch.bfloat16 and asserts out_unified == out_legacy and
pool_unified == pool_legacy with the same tolerances.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7258ed1b-b041-4012-9d81-69215acaecd1
📒 Files selected for processing (5)
- benchmarks/bench_gdn_decode.py
- flashinfer/__init__.py
- flashinfer/gdn_decode.py
- tests/gdn/test_decode_delta_rule.py
- tests/gdn/test_gdn_decode.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/__init__.py
♻️ Duplicate comments (2)
flashinfer/gdn_decode.py (2)
Lines 649-666: 🛠️ Refactor suggestion | 🟠 Major

Add the `@backend_requirement` decorator for the SM90+ requirement.

The unified API is documented as requiring SM90+ (line 719), but lacks the `@backend_requirement` decorator. Without it, callers on unsupported GPUs will fail late during JIT compilation rather than getting a clear capability check upfront.

As per coding guidelines: "Use the `@backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 649 - 666, The gated_delta_rule_decode API lacks the required SM90+ capability check decorator; add `@backend_requirement` with the SM90+ requirement above the gated_delta_rule_decode definition so callers get an upfront capability check. Ensure the decorator uses the existing helper functions is_compute_capability_supported(cc) and is_backend_supported() to express the SM90+ requirement (matching the unified API docs) so the GPU capability/ backend support is validated before JIT compilation.
Lines 940-974: ⚠️ Potential issue | 🟠 Major

Remove the deprecation warning from `gated_delta_rule_decode_kv`.

Per the PR discussion, reviewer kaixih recommended keeping `gated_delta_rule_decode_kv` as the stable explicit KV-layout entrypoint. This shim was introduced as the new name for the legacy KV path, so deprecating it immediately contradicts the migration guidance. Keep it as a thin delegation wrapper without the warning.

🔧 Suggested fix

```diff
 @flashinfer_api
 def gated_delta_rule_decode_kv(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     state: torch.Tensor,
     A_log: torch.Tensor,
     a: torch.Tensor,
     dt_bias: torch.Tensor,
     b: torch.Tensor,
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Deprecated: use gated_delta_rule_decode(..., state_layout="KV") instead."""
-    warnings.warn(
-        "gated_delta_rule_decode_kv is deprecated and will be removed in a future "
-        "version. Use gated_delta_rule_decode(..., state_layout='KV') instead.",
-        DeprecationWarning,
-        stacklevel=2,
-    )
+    """KV-layout decode API. Delegates to gated_delta_rule_decode with state_layout='KV'."""
     return gated_delta_rule_decode(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 940 - 974, The gated_delta_rule_decode_kv wrapper currently emits a DeprecationWarning and its docstring marks it deprecated; remove the warnings.warn call and the DeprecationWarning/stacklevel arguments and update the docstring to reflect that this is the stable KV-layout entrypoint, leaving the function body to simply delegate to gated_delta_rule_decode(..., state_layout="KV") with the same parameters (q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm).
🧹 Nitpick comments (2)
flashinfer/gdn_decode.py (2)
Lines 30-31: Unused `functools` import.

`functools` is imported but not used anywhere in the file. Remove it if not needed.

🧹 Suggested fix

```diff
-import functools
 import warnings
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 30 - 31, Remove the unused import of functools from the top of gdn_decode.py: locate the import statement "import functools" and delete it (ensure no references to functools exist elsewhere in functions or classes in this file, e.g., any usage in functions or decorators); after removal, run the linter/tests to confirm no unresolved references remain.
Lines 722-723: Unused variable `H`.

`H` is unpacked from `q.shape` but never used in the function body. Consider using `_` to indicate it is intentionally ignored.

🧹 Suggested fix

```diff
-    B, T, H, K = q.shape
+    B, T, _, K = q.shape
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 722 - 723, The tuple unpacking of q.shape assigns H but H is never used; change the unpack to ignore that dimension (e.g., B, T, _, K = q.shape) so the unused variable is explicit, and ensure no other code relies on H (update any references if present); leave the v unpack (_, _, HV, V = v.shape) as-is.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 649-666: The gated_delta_rule_decode API lacks the required SM90+
capability check decorator; add `@backend_requirement` with the SM90+ requirement
above the gated_delta_rule_decode definition so callers get an upfront
capability check. Ensure the decorator uses the existing helper functions
is_compute_capability_supported(cc) and is_backend_supported() to express the
SM90+ requirement (matching the unified API docs) so the GPU capability/ backend
support is validated before JIT compilation.
- Around line 940-974: The gated_delta_rule_decode_kv wrapper currently emits a
DeprecationWarning and its docstring marks it deprecated; remove the
warnings.warn call and the DeprecationWarning/stacklevel arguments and update
the docstring to reflect that this is the stable KV-layout entrypoint, leaving
the function body to simply delegate to gated_delta_rule_decode(...,
state_layout="KV") with the same parameters (q, k, v, state, A_log, a, dt_bias,
b, scale, output, use_qk_l2norm).
---
Nitpick comments:
In `@flashinfer/gdn_decode.py`:
- Around line 30-31: Remove the unused import of functools from the top of
gdn_decode.py: locate the import statement "import functools" and delete it
(ensure no references to functools exist elsewhere in functions or classes in
this file, e.g., any usage in functions or decorators); after removal, run the
linter/tests to confirm no unresolved references remain.
- Around line 722-723: The tuple unpacking of q.shape assigns H but H is never
used; change the unpack to ignore that dimension (e.g., B, T, _, K = q.shape) so
the unused variable is explicit, and ensure no other code relies on H (update
any references if present); leave the v unpack (_, _, HV, V = v.shape) as-is.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2de30a05-590d-4b0c-b214-d6829960d2bf
📒 Files selected for processing (3)
- benchmarks/bench_gdn_decode.py
- flashinfer/gdn_decode.py
- tests/gdn/test_decode_delta_rule.py
🚧 Files skipped from review as they are similar to previous changes (2)
- benchmarks/bench_gdn_decode.py
- tests/gdn/test_decode_delta_rule.py
cc @kaixih
```python
first_bad = state_indices[bad].flatten()[0].item()
raise ValueError(
    f"state_indices must be in [0, pool_size={pool_size}); got out-of-range value {first_bad}"
)
```
Do we need this bounds check? Two concerns: (1) the range validation adds overhead that may not be acceptable on the hot path — we generally rely on the framework/caller to ensure indices are valid. (2) More importantly, negative values are intentional: they represent padding (dummy) entries during CUDA graph capture and should be supported, not rejected.
```python
@flashinfer_api
def gated_delta_rule_decode(
```
Maybe add `@backend_requirement` for the SM90+ guard.
```python
)
if T == 1:
    if use_pool:
        raise NotImplementedError(
```
Can we double check whether this is supported? From the existing code, I can see:

```python
run_pretranspose_decode(
    h0_source,
    A_log,
    a,
    dt_bias,
    q,
    k,
    v,
    b,
    output,
    B,
    T,
    H,
    HV,
    K,
    V,
    scale,
    use_qk_l2norm,
    use_pool_indexing=use_pool_indexing,
    initial_state_indices=initial_state_indices,
)
```
📌 Description
Implements RFC 5.7/5.8: add unified GDN decode API `gated_delta_rule_decode_unified(state_layout, state_indices, intermediate_states_buffer, disable_state_update)`, rename backends to `_*_impl` with dispatch by layout/dtype/T, add deprecation shims for the three legacy decode APIs, export in `__init__.py`, and add tests/gdn/test_gdn_decode_unified.py (cross-check, intermediate_states_buffer, edge cases, error paths).

🔍 Related Issues
Closes #2687
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
…(unittest, etc.).

Reviewer Notes