
feat: add pool+indices support to gated_delta_rule_decode_pretranspose (bf16 path) #2619

Merged
yzh119 merged 2 commits into flashinfer-ai:main from kaixih:gdn-pool-indices
Mar 4, 2026

Conversation

@kaixih
Collaborator

@kaixih kaixih commented Feb 22, 2026

Summary

  • Adds initial_state=[pool, HV, V, K] and initial_state_indices=[B] parameters
    to gated_delta_rule_decode_pretranspose, allowing callers to pass an SSM state
    pool + per-batch indices directly instead of doing gather/scatter around the kernel.
  • The kernel reads and writes the pool in-place, so no separate scatter step is needed
    after the call.
  • Only supported via the bf16 fast path (gdn_decode_klast_bf16_state): bfloat16
    state, T in 1..4, K=V=128.
Callsite Before:
state = ssm_states[cache_indices]   # gather
out, state = kernel(state=state, ...)
ssm_states[cache_indices] = state   # scatter

Callsite After:
out, _ = kernel(initial_state=ssm_states, initial_state_indices=cache_indices, ...)
# ssm_states updated in-place, no scatter needed
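The in-place pool semantics can be illustrated with a small NumPy sketch. Everything here is a stand-in: `fake_kernel` is a dummy update rule (not the real GDN recurrence) and the shapes are toy-sized; the point is only that per-slot in-place updates match the gather/scatter reference.

```python
import numpy as np

def fake_kernel(state):
    # Dummy stand-in for the kernel's state update (not the GDN math).
    return state * 2.0 + 1.0

rng = np.random.default_rng(0)
pool_size, HV, V, K, B = 6, 2, 4, 4, 3
pool = rng.standard_normal((pool_size, HV, V, K)).astype(np.float32)
orig = pool.copy()
indices = np.array([4, 0, 2], dtype=np.int32)  # one pool slot per batch row

# Old callsite: gather -> kernel -> scatter.
ref_pool = pool.copy()
gathered = ref_pool[indices]                   # [B, HV, V, K]
ref_pool[indices] = fake_kernel(gathered)      # explicit scatter

# New callsite: the kernel looks up each batch's slot and writes in place.
for b in range(B):
    pool[indices[b]] = fake_kernel(pool[indices[b]])

assert np.allclose(pool, ref_pool)             # same result, no scatter step
untouched = np.setdiff1d(np.arange(pool_size), indices)
assert np.allclose(pool[untouched], orig[untouched])  # other slots untouched
```

Slots not referenced by `indices` are left untouched under both schemes, which is exactly what the new pool test checks against the gather-run-scatter reference.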

Changes

  • gdn_decode.py: add initial_state/initial_state_indices params; route pool
    path through the bf16 kernel
  • gdn_decode_bf16_state.py: add gH0_indices to all 3 kernels and 5 launch
    functions; gated_delta_rule now uses identity indices when
    initial_state_indices=None
  • tests/gdn/test_decode_delta_rule.py: add pool+indices test against
    gather→direct-state reference

Summary by CodeRabbit

  • New Features
    • Optional external initial-state pool with per-batch indexing for GDN decode, plus an optimized BF16 fast path for pool-based state to improve inference performance; backward compatible (parameters are optional).
  • Bug Fixes
    • Stronger validation and gating to ensure correct path selection and in-place state updates.
  • Documentation
    • Clarified docstrings describing pool layout and index behavior.
  • Tests
    • Added tests covering the pool+indices pretranspose decode path and in-place update consistency.

…e (bf16 path)

Allows callers to pass initial_state=[pool,HV,V,K] + initial_state_indices=[B]
directly, avoiding expensive gather/transpose/scatter around the kernel call.
The kernel reads/writes the pool in-place via per-batch index lookup.

Only supported via the bf16 fast path (gdn_decode_klast_bf16_state):
bfloat16 state, T in 1..4, K=V=128. Float32 legacy path raises an error.

Changes:
- gdn_decode.py: add initial_state/initial_state_indices params to
  gated_delta_rule_decode_pretranspose; route pool path through bf16 kernel
- gdn_decode_bf16_state.py: add gH0_indices param to all 3 kernels and
  5 launch functions; gated_delta_rule now creates identity indices when
  initial_state_indices is None and includes pool_size in cache key
- tests: update test_decode_kernel_pretranspose_pool to use bfloat16 state

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai bot commented Feb 22, 2026

📝 Walkthrough

Walkthrough

Adds optional pool-based initial-state support to GDN decode kernels and Python wrappers with per-batch initial_state_indices. Kernels and dispatch now route either a bf16 pool fast-path (updates in-place into the pool) or a legacy in-place state path, preserving backward compatibility.

Changes

  • GDN Decode API Layer — flashinfer/gdn_decode.py: expanded the public APIs to accept initial_state: Optional[Tensor] and initial_state_indices: Optional[Tensor], made state optional, added validation/dispatch to choose the pool (bf16) vs. legacy path, updated kernel launches, and changed return semantics (the pool tensor is returned when the pool path is used).
  • BF16 Kernel Implementations — flashinfer/gdn_kernels/gdn_decode_bf16_state.py: threaded per-batch pool-slot index tensors (gH_slot_indices/mH_slot_indices) through kernel signatures and launches, replaced batch-indexed H lookups with pool-mapped lookups, added pool_size to cache keys, and enabled in-kernel in-place updates into the pool.
  • Tests — tests/gdn/test_decode_delta_rule.py: added _test_decode_kernel_pretranspose_pool and test_decode_kernel_pretranspose_pool to validate the pool+indices pretranspose path, in-place pool updates, and equivalence with a direct-state reference execution.
  • Misc / Cleanup — flashinfer/gdn_decode.py, flashinfer/gdn_kernels/...: removed obsolete memory-size lines, silenced debug prints, tightened dtype/shape gating for the bf16 pool path vs. the float32 legacy path, and adjusted docstrings.

Sequence Diagram(s)

sequenceDiagram
    participant API as Python API
    participant Dispatch as Validation / Dispatch
    participant Pool as Initial-State Pool
    participant Legacy as Legacy State
    participant Kernel as GPU Kernel
    participant Caller as Caller (returns)

    API->>Dispatch: call gated_delta_rule_decode(..., initial_state?, initial_state_indices?)
    Dispatch->>Dispatch: validate shapes, dtype, K/V constraints
    alt Pool path (initial_state provided, bf16, K=V=128, T in 1..4)
        Dispatch->>Pool: read initial_state & indices
        Dispatch->>Kernel: launch bf16 pool-accelerated kernel with indices
        Kernel->>Pool: per-batch pool_batch_idx = indices[batch]
        Kernel->>Pool: update pool slots in-place
        Kernel-->>Caller: return initial_state (pool)
    else Legacy path
        Dispatch->>Legacy: use provided state tensor (or in-place path)
        Dispatch->>Kernel: launch legacy float32 kernel (no indices)
        Kernel->>Legacy: update state in-place
        Kernel-->>Caller: return state (legacy)
    end
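The routing in the diagram can be sketched in plain Python. This is an illustration only: `dispatch_decode` and the returned kernel names are placeholders, not the actual FlashInfer symbols; the constraints mirror the ones stated in the PR description (bfloat16 state, T in 1..4, K=V=128).

```python
def dispatch_decode(state_dtype, seq_len, K, V,
                    initial_state=None, initial_state_indices=None):
    """Sketch of the pool-vs-legacy routing; kernel names are placeholders."""
    use_pool = initial_state is not None
    if use_pool:
        if initial_state_indices is None:
            raise ValueError("initial_state requires initial_state_indices")
        # The pool path is only wired through the bf16 fast path.
        if not (state_dtype == "bfloat16" and K == V == 128 and 1 <= seq_len <= 4):
            raise ValueError("pool+indices needs bfloat16 state, T in 1..4, K=V=128")
        return "bf16_pool_kernel"   # reads/writes pool slots in place
    return "legacy_kernel"          # updates the provided state tensor in place

assert dispatch_decode("bfloat16", 2, 128, 128,
                       initial_state=object(),
                       initial_state_indices=[0, 1]) == "bf16_pool_kernel"
assert dispatch_decode("float32", 1, 128, 128) == "legacy_kernel"
```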

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

v0.6.2, model: qwen3-next

Suggested reviewers

  • cyx-6
  • bkryu
  • nvmbreughe
  • kahyunnam
  • jimmyzho
  • yzh119

Poem

🐰 In pools of bits my state shall hide,

Indices hop to match each stride,
bf16 wings carry the flight,
Legacy keeps the old path tight,
In-place updates — a rabbit's pride.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 53.33%, below the required 80.00% threshold. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title clearly describes the main change, adding pool+indices support to gated_delta_rule_decode_pretranspose on the bf16 fast path, which is the primary feature introduced in this PR.
  • Description check — ✅ Passed: the description includes a clear summary, before/after code examples, the specific files modified (gdn_decode.py, gdn_decode_bf16_state.py, tests/gdn/test_decode_delta_rule.py), and implementation details. It does not follow the provided template structure, however: the 'Related Issues' section and explicit checklist marks are missing, though pre-commit and test information are implicitly covered.


@gemini-code-assist
Contributor

Summary of Changes

Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes state management within the gated_delta_rule_decode_pretranspose function by enabling direct, indexed access to a state pool. This enhancement streamlines the process of handling SSM states in batched inference scenarios, reducing boilerplate code and potential overhead associated with explicit gather and scatter operations. The change primarily benefits the bfloat16 fast path, providing a more efficient mechanism for state updates during single-token generation.

Highlights

  • Direct State Pool Access: Introduced initial_state and initial_state_indices parameters to gated_delta_rule_decode_pretranspose, allowing callers to pass a state pool and per-batch indices directly to the kernel. This eliminates the need for manual gather/scatter operations around the kernel call.
  • In-place State Updates: The kernel now reads from and writes to the provided state pool in-place, simplifying state management for the caller as no separate scatter step is required after the kernel execution.
  • BF16 Fast Path Support: This new pool+indices functionality is specifically supported via the bfloat16 fast path (gdn_decode_klast_bf16_state) for T in 1..4 and K=V=128 configurations, with explicit assertions preventing its use with the legacy float32 path.
  • Kernel Argument Updates: Modified the underlying CUDA kernels and launch functions (gdn_decode_bf16_state.py) to accept and utilize the gH0_indices for indexed access to the state pool, ensuring correct state retrieval and updates.
  • Comprehensive Testing: Added a new test suite that rigorously verifies the correctness of the pool+indices path by comparing its output and state updates against a gather-run-scatter reference implementation, including checks for untouched pool slots.


Changelog
  • flashinfer/gdn_decode.py
    • Removed unused debug print statements from kernel run functions.
    • Modified gated_delta_rule_decode_pretranspose to accept optional initial_state and initial_state_indices parameters.
    • Updated docstrings for gated_delta_rule_decode_pretranspose to reflect new parameters and in-place state update behavior.
    • Added validation logic to ensure initial_state and initial_state_indices are used together and have correct shapes.
    • Adjusted the state_dtype determination to account for the new initial_state parameter.
    • Updated the call to the gdn_decode_klast_bf16_state backend to pass the appropriate state source and indices.
    • Modified the return value of gated_delta_rule_decode_pretranspose to return initial_state when the pool path is used.
    • Added an assertion to prevent the use of the pool+indices path with the legacy float32 state path.
    • Reordered state copy logic for the legacy path and updated the compiled kernel call syntax.
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
    • Added gH0_indices as a parameter to gated_delta_rule_decode_kernel_seqlen1, gated_delta_rule_decode_kernel_seqlen234_unified, and gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk.
    • Introduced pool_batch_idx calculation using gH0_indices to map batch indices to state pool indices.
    • Updated global memory accesses for state (gH) within kernels to use pool_batch_idx for indexed lookup.
    • Modified launch functions (gated_delta_rule_launch_seqlen1, gated_delta_rule_launch_seqlen1_lowBS_1chunk, gated_delta_rule_launch_seqlen2, gated_delta_rule_launch_seqlen3, gated_delta_rule_launch_seqlen4) to pass mH0_indices to the respective kernels.
    • Updated the docstring for gated_delta_rule to clarify the usage of initial_state_source and initial_state_indices.
    • Implemented logic in gated_delta_rule to generate identity indices if initial_state_indices is None and ensure indices are int32.
    • Added h0_indices_ to the arguments passed to the compiled kernel functions.
    • Included pool_size in the cache key for compiled kernels to differentiate between direct and pooled state paths.
  • tests/gdn/test_decode_delta_rule.py
    • Added a new test function _test_decode_kernel_pretranspose_pool to verify the pool+indices path.
    • Implemented a parametrized test test_decode_kernel_pretranspose_pool to cover various batch sizes and configurations.
    • The new test compares outputs and state updates of the pool+indices path against a gather-run-scatter reference.
    • Included assertions to ensure selected pool slots match updated states and non-selected slots remain unchanged.

@kaixih
Collaborator Author

kaixih commented Feb 22, 2026

cc @hlu1

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request successfully introduces pool+indices support to the gated_delta_rule_decode_pretranspose function, specifically targeting the bfloat16 fast path. This enhancement allows for more efficient state management by enabling in-place updates within a state pool, thereby eliminating the need for external gather/scatter operations. The implementation correctly routes the new parameters through the backend kernels and includes comprehensive tests to verify correctness and in-place behavior. My feedback focuses on ensuring device consistency for the newly introduced indices and optimizing the allocation of identity indices in the hot path.

Comment on lines +2011 to +2016
if initial_state_indices is None:
    h0_indices = torch.arange(B, dtype=torch.int32, device=q.device)
elif initial_state_indices.dtype != torch.int32:
    h0_indices = initial_state_indices.to(torch.int32)
else:
    h0_indices = initial_state_indices
Contributor


Severity: high

There are two points for improvement here:

  1. Device Consistency: When initial_state_indices is provided by the user, it should be explicitly moved to the same device as the input tensors (e.g., q.device). If the user passes indices residing on the CPU, the current implementation will pass a CPU pointer to the CUDA kernel via from_dlpack, which will cause a runtime crash during kernel execution.
  2. Performance Optimization: For the direct path (where initial_state_indices is None), torch.arange is called on every invocation. While fast for small batch sizes, it still introduces a small allocation overhead in the hot path. Since the kernel is already specialized for batch size B in the compilation cache, it would be more efficient to cache these identity indices or reuse a pre-allocated buffer.
    if initial_state_indices is None:
        h0_indices = torch.arange(B, dtype=torch.int32, device=q.device)
    else:
        h0_indices = initial_state_indices.to(device=q.device, dtype=torch.int32)
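Both points can be folded into one normalization helper. The NumPy sketch below is illustrative only (`normalize_h0_indices` is a made-up name, and a torch version would additionally move the indices to `q.device`); it also adds the bounds check requested elsewhere in this review.

```python
import numpy as np

def normalize_h0_indices(initial_state_indices, B, pool_size):
    """Default to identity indices; normalize dtype and bounds-check values."""
    if initial_state_indices is None:
        # Identity mapping: batch row b reads/writes pool slot b.
        return np.arange(B, dtype=np.int32)
    idx = np.asarray(initial_state_indices, dtype=np.int32)
    if idx.shape != (B,):
        raise ValueError(f"expected indices of shape ({B},), got {idx.shape}")
    if idx.min() < 0 or idx.max() >= pool_size:
        raise IndexError(f"index values must lie in [0, {pool_size})")
    return idx

assert normalize_h0_indices(None, 3, 8).tolist() == [0, 1, 2]
assert normalize_h0_indices([7, 0, 3], 3, 8).dtype == np.int32
```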

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/gdn/test_decode_delta_rule.py (1)

511-535: Pool+indices test only exercises T=1; T=2/3/4 not covered

The feature is described as supporting T in 1..4, but _test_decode_kernel_pretranspose_pool hardcodes q = torch.randn(batch_size, 1, ...) with no seq_len parameter. The two distinct kernel paths for T=1 (low-BS 1-chunk vs. standard) are exercised, but gated_delta_rule_decode_kernel_seqlen234_unified with the new gH0_indices argument is never exercised by a pool+indices test.

Consider adding a seq_len: int parameter (defaulting to 1) and parametrizing:

+@pytest.mark.parametrize("seq_len", [1, 2, 3, 4])
 @pytest.mark.parametrize("scale", [1.0])
 ...
 def test_decode_kernel_pretranspose_pool(
     ...
+    seq_len: int,
     seed: int = int(os.environ.get("SEED", "0")),
 ):
     _test_decode_kernel_pretranspose_pool(
-        dtype, batch_size, ..., seed=seed,
+        dtype, batch_size, ..., seq_len=seq_len, seed=seed,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 511 - 535, The pool+indices
test only uses seq_len=1 because _test_decode_kernel_pretranspose_pool currently
hardcodes q = torch.randn(batch_size, 1, ...); add a seq_len: int parameter
(default 1) to _test_decode_kernel_pretranspose_pool and to the test wrapper
test_decode_kernel_pretranspose_pool, use that seq_len when creating q so tests
can exercise T=2/3/4, and parametrize the test to include seq_len values 1,2,3,4
so the gated_delta_rule_decode_kernel_seqlen234_unified path and the gH0_indices
argument are exercised.
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)

2011-2016: Per-call torch.arange(B) allocation each decode step when initial_state_indices=None

The identity-indices tensor is recreated on every invocation. The analogous nontranspose path in gdn_decode.py (line 1934) avoids this by caching the tensor in the compiled-kernel cache dict. At high decode throughput this becomes repeated allocation/deallocation overhead.

♻️ Suggested approach (mirroring the nontranspose pattern)
+    if "h0_indices" not in _compiled_kernels.get(cache_key, {}):
+        _id_indices = torch.arange(B, dtype=torch.int32, device=q.device)
+    else:
+        _id_indices = _compiled_kernels[cache_key].get("h0_indices")
+
     # Resolve indices: identity mapping when not provided
     if initial_state_indices is None:
-        h0_indices = torch.arange(B, dtype=torch.int32, device=q.device)
+        h0_indices = _id_indices
     elif initial_state_indices.dtype != torch.int32:
         h0_indices = initial_state_indices.to(torch.int32)
     else:
         h0_indices = initial_state_indices

Alternatively, cache the identity tensor in _compiled_kernels[cache_key] after the first compile (similar to how cu_seqlens is cached in the pretranspose path).
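A minimal sketch of that caching idea, assuming per-batch-size memoization is acceptable (`identity_indices` is an illustrative name; a real torch version would also key on device and could live in the compiled-kernel cache dict as suggested):

```python
from functools import lru_cache

import numpy as np

@lru_cache(maxsize=None)
def identity_indices(B):
    # Built once per batch size and reused, avoiding a per-call allocation.
    return np.arange(B, dtype=np.int32)

a = identity_indices(4)
assert identity_indices(4) is a           # cache hit: same object returned
assert a.tolist() == [0, 1, 2, 3]
```

Since the cached array is shared across calls, callers must treat it as read-only.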

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2011 - 2016,
The code recreates the identity indices tensor each call when
initial_state_indices is None; change gdn_decode_bf16_state.py so that if
initial_state_indices is None you first check _compiled_kernels[cache_key] for a
cached identity tensor and use that, and if missing create torch.arange(B,
dtype=torch.int32, device=q.device), store it in _compiled_kernels[cache_key]
(same cache slot used for cu_seqlens in the pretranspose path), and then assign
it to h0_indices; keep the existing branches for when initial_state_indices is
provided and ensure dtype remains torch.int32.

Comment on lines +1017 to +1022
if use_pool:
pool_size = initial_state.shape[0]
assert initial_state.shape == (pool_size, HV, V, K), (
f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
f"got {initial_state.shape}"
)
Contributor


⚠️ Potential issue | 🟡 Minor

No bounds check on initial_state_indices values — out-of-bounds indices cause silent GPU memory corruption

initial_state.shape is validated, but individual index values in initial_state_indices are never checked against [0, pool_size). An out-of-range index produces a CUDA illegal-access fault or, worse, silently overwrites an adjacent allocation without any Python-level diagnostic.

🛡️ Proposed fix
 if use_pool:
     pool_size = initial_state.shape[0]
     assert initial_state.shape == (pool_size, HV, V, K), (
         f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
         f"got {initial_state.shape}"
     )
+    assert initial_state_indices.shape == (B,), (
+        f"Expected initial_state_indices shape [{B}], got {initial_state_indices.shape}"
+    )
+    assert (
+        int(initial_state_indices.min()) >= 0
+        and int(initial_state_indices.max()) < pool_size
+    ), (
+        f"initial_state_indices values must be in [0, {pool_size}), "
+        f"got [{int(initial_state_indices.min())}, {int(initial_state_indices.max())}]"
+    )

Note: .min()/.max() trigger a host sync; if that's unacceptable on a hot path, guard this behind assert statements that can be compiled-out, or document the constraint clearly in the docstring.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 1017 - 1022, The code validates
initial_state shape but never checks values in initial_state_indices; add a
bounds check (e.g., compute min and max of initial_state_indices and compare
against 0 and pool_size) and raise a clear error (IndexError or ValueError) if
any index is out of range before using initial_state_indices in use_pool path;
reference the variables initial_state_indices, initial_state, pool_size and the
use_pool branch and prefer using assert for hot-path avoidance if you need this
check to be compiled out (or document the requirement in the function
docstring).

Comment on lines +406 to +419
def _test_decode_kernel_pretranspose_pool(
    dtype: str,
    batch_size: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    scale: float,
    pool_multiplier: int = 3,
    state_dtype: str = "bfloat16",
    seed: int | None = None,
):
    """Pool+indices path must match gather → direct-state → scatter reference."""
    _skip_if_not_sm90_or_later()
Contributor


⚠️ Potential issue | 🟠 Major

Missing GDN_DECODE_KLAST_BF16_STATE_AVAILABLE guard — test will hard-fail instead of skip

Every test that exercises the bf16-state kernel path (_test_gdn_decode_klast_bf16_state_kernel, test_pretranspose_api_uses_gdn_decode_klast_bf16_state) checks this flag before proceeding. This new helper omits it, so on any machine where gdn_decode_klast_bf16_state cannot be imported, gated_delta_rule_decode_pretranspose hits assert not use_pool in the legacy path and raises AssertionError instead of a clean pytest.skip.

🛡️ Proposed fix
 def _test_decode_kernel_pretranspose_pool(
     ...
 ):
     """Pool+indices path must match gather → direct-state → scatter reference."""
     _skip_if_not_sm90_or_later()
+    if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE:
+        pytest.skip("gdn_decode_klast_bf16_state kernel not available")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 406 - 419, The helper
_test_decode_kernel_pretranspose_pool must early-skip when the native bf16-state
kernel is unavailable: add the same GDN_DECODE_KLAST_BF16_STATE_AVAILABLE guard
used by _test_gdn_decode_klast_bf16_state_kernel and
test_pretranspose_api_uses_gdn_decode_klast_bf16_state (e.g., check
GDN_DECODE_KLAST_BF16_STATE_AVAILABLE at the top of
_test_decode_kernel_pretranspose_pool and call pytest.skip with a clear message
before calling _skip_if_not_sm90_or_later()), so
gated_delta_rule_decode_pretranspose never hits the legacy-path assert on
machines without the kernel.

@yzh119
Collaborator

yzh119 commented Feb 22, 2026

Hi @kaixih does this PR serve the same purpose as #2521?

@kaixih
Collaborator Author

kaixih commented Feb 23, 2026

Hi @kaixih does this PR serve the same purpose as #2521?

PR #2521 covers the float32 pretranspose path; this PR covers the bf16 fast path (gdn_decode_bf16_state.py), which has separate kernel implementations. They're complementary, not duplicates.

@yzh119
Collaborator

yzh119 commented Feb 23, 2026

Sounds good, just want to make sure the interface and semantics are aligned.

@kaixih
Collaborator Author

kaixih commented Feb 23, 2026

TBH, the current GDN inference APIs are a bit cluttered. We’re seeing a lot of forked paths to support different layouts (V-last vs. K-last), state dtypes (bf16 vs. fp32), and batch sizes, as well as varying MTP support. In my opinion, we should plan a refactor after this PR to converge these paths and stabilize the API.

@kaixih
Collaborator Author

kaixih commented Feb 24, 2026

gentle ping. thx @yzh119

@kaixih
Collaborator Author

kaixih commented Feb 24, 2026

Also, please let us know the plan for this. If there isn’t one yet, we’d be happy to contribute.

#2619 (comment)

@yzh119
Collaborator

yzh119 commented Feb 24, 2026

Hi @kaixih

Also, please let us know the plan for this. If there isn’t one yet, we’d be happy to contribute.

I don't think it's planned, and think it would be great if you can work on this.

@kaixih
Collaborator Author

kaixih commented Feb 26, 2026

@yzh119 any updates?

…yncs

Replace the ambiguous H0 subscript notation (which implies "H at time zero")
with slot_indices to clearly convey pool slot selection. Also drop two
defensive torch.cuda.synchronize() calls in the pool test that are
unnecessary: the two kernel calls operate on separate tensors, and CUDA
stream serialization already ensures ordering.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@kaixih
Collaborator Author

kaixih commented Mar 3, 2026

@yzh119 PTAL

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/gdn/test_decode_delta_rule.py (1)

406-420: ⚠️ Potential issue | 🟠 Major

Add bf16-kernel availability skip guard in this helper.

At Line 406-Line 420, this helper can still hard-fail when gdn_decode_klast_bf16_state is unavailable. Mirror the existing GDN_DECODE_KLAST_BF16_STATE_AVAILABLE skip pattern used elsewhere in this file before running the pool-path call.

🔧 Suggested fix
 def _test_decode_kernel_pretranspose_pool(
@@
 ):
     """Pool+indices path must match gather → direct-state → scatter reference."""
+    if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE:
+        pytest.skip("gdn_decode_klast_bf16_state kernel not available")
     _skip_if_not_sm90_or_later()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_decode_delta_rule.py` around lines 406 - 420, The helper
_test_decode_kernel_pretranspose_pool may hard-fail if the bf16 kernel
gdn_decode_klast_bf16_state is unavailable; add the same availability guard used
elsewhere (check GDN_DECODE_KLAST_BF16_STATE_AVAILABLE) at the start of this
helper and skip the test when false, mirroring the existing pattern in this file
so the pool-path call is not executed unless the bf16 kernel is present.
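The equivalence this pool test is built around (in-place pool update must match the gather → compute → scatter reference) can be sketched torch-free; `step` below is a stand-in for the decode kernel's state update, and all names are illustrative:

```python
# Stand-in for the kernel's per-slot state update.
def step(state):
    return [x + 1.0 for x in state]

pool = {0: [1.0], 1: [2.0], 2: [3.0], 3: [4.0]}
indices = [2, 0]  # per-batch slot indices into the pool

# Reference path: gather, compute, scatter.
ref = dict(pool)
gathered = [ref[i] for i in indices]
updated = [step(s) for s in gathered]
for i, s in zip(indices, updated):
    ref[i] = s

# Pool path: update the selected slots in place.
for i in indices:
    pool[i] = step(pool[i])

assert pool == ref  # untouched slots and updated slots both agree
```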

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 29de39a and a6bc95d.

📒 Files selected for processing (2)
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • tests/gdn/test_decode_delta_rule.py

Comment on lines +2005 to +2016
pool_size = initial_state_source.shape[0]

if scale is None:
    scale = 1.0 / math.sqrt(K)

# Resolve indices: identity mapping when not provided
if initial_state_indices is None:
    h_slot_indices = torch.arange(B, dtype=torch.int32, device=q.device)
elif initial_state_indices.dtype != torch.int32:
    h_slot_indices = initial_state_indices.to(torch.int32)
else:
    h_slot_indices = initial_state_indices
Contributor


⚠️ Potential issue | 🔴 Critical

Validate initial_state_indices before launching kernels to prevent OOB/racy state writes.

At Line 2010-Line 2016, only dtype normalization is done. Without shape/device/range (and duplicate-slot) checks, kernel reads/writes via gH_slot_indices[batch_idx] (Line 734/Line 1145/Line 1578) can go out of bounds or have concurrent writes to the same pool slot.

🔧 Suggested fix
-    pool_size = initial_state_source.shape[0]
+    pool_size = initial_state_source.shape[0]

     # Resolve indices: identity mapping when not provided
     if initial_state_indices is None:
+        if pool_size < B:
+            raise ValueError(
+                f"initial_state_source.shape[0] ({pool_size}) must be >= batch size ({B}) "
+                "when initial_state_indices is None"
+            )
         h_slot_indices = torch.arange(B, dtype=torch.int32, device=q.device)
-    elif initial_state_indices.dtype != torch.int32:
-        h_slot_indices = initial_state_indices.to(torch.int32)
     else:
-        h_slot_indices = initial_state_indices
+        if initial_state_indices.ndim != 1 or initial_state_indices.numel() != B:
+            raise ValueError(
+                f"initial_state_indices must have shape [B]={B}, got {tuple(initial_state_indices.shape)}"
+            )
+        h_slot_indices = initial_state_indices.to(
+            device=q.device, dtype=torch.int32
+        ).contiguous()
+        if torch.any(h_slot_indices < 0) or torch.any(h_slot_indices >= pool_size):
+            raise ValueError(
+                f"initial_state_indices must be in [0, {pool_size}), got min/max="
+                f"({int(h_slot_indices.min())}, {int(h_slot_indices.max())})"
+            )
+        if torch.unique(h_slot_indices).numel() != B:
+            raise ValueError(
+                "initial_state_indices must be unique per batch entry to avoid concurrent "
+                "writes to the same pool slot"
+            )

Also applies to: 2028-2028
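The intent of the suggested checks can be captured in a torch-free analogue; `validate_slot_indices` is an illustrative helper, not part of the flashinfer API:

```python
# Validate per-batch slot indices before launching a kernel that reads and
# writes the pool in place: shape must match the batch, every index must be
# in range, and duplicates must be rejected to avoid racy concurrent writes.
def validate_slot_indices(indices, batch_size, pool_size):
    if len(indices) != batch_size:
        raise ValueError(f"expected {batch_size} indices, got {len(indices)}")
    bad = [i for i in indices if i < 0 or i >= pool_size]
    if bad:
        raise ValueError(f"indices out of range [0, {pool_size}): {bad}")
    if len(set(indices)) != len(indices):
        raise ValueError("duplicate slot indices would race on in-place writes")
    return True

assert validate_slot_indices([2, 0, 3], batch_size=3, pool_size=4)
```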

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2005 - 2016,
The code normalizes dtype for initial_state_indices but does not validate shape,
device, bounds, or duplicates, which can cause out-of-bounds or concurrent
writes when kernels read gH_slot_indices; update the initialization of
h_slot_indices (when initial_state_indices is provided or defaulted) to: 1)
ensure the tensor is on q.device, 2) verify its shape matches B (or
broadcast/raise), 3) check all indices are >=0 and < pool_size, and 4) detect
duplicate indices and either reject or remap/report them before launching
kernels; perform these checks where h_slot_indices is set (related symbols:
initial_state_indices, h_slot_indices, initial_state_source, pool_size,
gH_slot_indices) and raise a clear error if validation fails.

Collaborator

@yzh119 yzh119 left a comment


LGTM

@yzh119
Collaborator

yzh119 commented Mar 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !369 has been created, and the CI pipeline #45293540 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45293540: 10/20 passed

@ZJY0516
Contributor

ZJY0516 commented Mar 4, 2026

Is Hopper supported? I'm encountering a problem.

(EngineCore_DP0 pid=1368071) ⚙️  Current Settings:
(EngineCore_DP0 pid=1368071) - CUDA Toolkit Path: /usr/local/cuda
(EngineCore_DP0 pid=1368071) - Target Architecture: sm_90a
(EngineCore_DP0 pid=1368071) 
(EngineCore_DP0 pid=1368071) IR Context (truncated):
(EngineCore_DP0 pid=1368071) "gpu.module"() <{sym_name = "kernels", targets = [#nvvm.target<O = 3, chip = "sm_90a", flags = {"ptx-cmd-options" = []}>]}> ({
(EngineCore_DP0 pid=1368071)     "llvm.mlir.global"() <{addr_space = 3 : i32, alignment = 1024 : i64, dso_local, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage<external>, sym_name = "__dynamic_shmem__0", visibility_ = 0 : i64}> ({
(EngineCore_DP0 pid=1368071)     }) : () -> ()
(EngineCore_DP0 pid=1368071)     "llvm.func"() <{CConv = #llvm.cconv<ccc>, function_type = !llvm.func<f32 (f32)>, linkage = #llvm.linkage<external>, sym_name = "__nv_rsqrtf", visibility_ = 0 : i64}> ({
(EngineCore_DP0 pid=1368071)     }) : () -> ()
(EngineCore_DP0 pid=1368071)   ...
(EngineCore_DP0 pid=1368071)       "llvm.return"() : () -> ()
(EngineCore_DP0 pid=1368071)     }) {cu_attrs = {max_dynamic_shared_size_bytes = #cuda.dev_max_shared_memory_optin, non_portable_cluster_size_allowed = 1 : i32}, gpu.kernel, nvvm.kernel, nvvm.reqntid = array<i32: 128, 1, 1>} : () -> ()
(EngineCore_DP0 pid=1368071)   }) {compute_targets = [#cuda.compute_target<sass, conditional, [sm_90]>]} : () -> ()
(EngineCore_DP0 pid=1368071) error: "<module>"("/mnt/data1/zjy/code/flashinfer/flashinfer/gdn_kernels/gdn_decode_bf16_state.py":1495:0): Failed translating the module to ISA.
(EngineCore_DP0 pid=1368071)  note: "<module>"("/mnt/data1/zjy/code/flashinfer/flashinfer/gdn_kernels/gdn_decode_bf16_state.py":1495:0):
(EngineCore_DP0 pid=1368071) 
(EngineCore_DP0 pid=1368071) 💡 Possible Solutions:
(EngineCore_DP0 pid=1368071) 1. Check if CUDA_TOOLKIT_PATH is set correctly
(EngineCore_DP0 pid=1368071) 2. Verify target architecture (sm_90a) is supported by your CUDA toolkit
(EngineCore_DP0 pid=1368071) 3. Make sure CUDA toolkit version matches the target architecture requirements
(EngineCore_DP0 pid=1368071) 
(EngineCore_DP0 pid=1368071) During handling of the above exception, another exception occurred:

@kaixih kaixih mentioned this pull request Mar 4, 2026
@yzh119 yzh119 merged commit eee401b into flashinfer-ai:main Mar 4, 2026
31 of 36 checks passed
xutizhou added a commit to xutizhou/flashinfer that referenced this pull request Mar 5, 2026
…om PR flashinfer-ai#2619

Resolve merge conflicts with upstream main which added pool+indices support
via the bf16 fast path (PR flashinfer-ai#2619). Key changes:
- Adopt upstream API naming: state_indices -> initial_state_indices,
  state pool passed via initial_state param
- Update test_decode_pooled.py to use new API with bf16 state
- Skip negative-index tests (bf16 kernel does not support them yet)
- Legacy f32 CuTe DSL path preserved for non-pool usage

AI-assisted merge resolution.
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…e (bf16 path) (flashinfer-ai#2619)

## Summary
- Adds `initial_state=[pool, HV, V, K]` and `initial_state_indices=[B]` parameters to `gated_delta_rule_decode_pretranspose`, allowing callers to pass an SSM state pool + per-batch indices directly instead of doing gather/scatter around the kernel.
- The kernel reads and writes the pool in-place, so no separate scatter step is needed after the call.
- Only supported via the bf16 fast path (`gdn_decode_klast_bf16_state`): bfloat16 state, T in 1..4, K=V=128.

```python
# Callsite before:
state = ssm_states[cache_indices]   # gather
out, state = kernel(state=state, ...)
ssm_states[cache_indices] = state   # scatter

# Callsite after:
out, _ = kernel(initial_state=ssm_states, initial_state_indices=cache_indices, ...)
# ssm_states updated in-place, no scatter needed
```

## Changes

- `gdn_decode.py`: add `initial_state`/`initial_state_indices` params; route the pool path through the bf16 kernel
- `gdn_decode_bf16_state.py`: add `gH0_indices` to all 3 kernels and 5 launch functions; `gated_delta_rule` now uses identity indices when `initial_state_indices=None`
- `tests/gdn/test_decode_delta_rule.py`: add pool+indices test against the gather→direct-state reference

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Optional external initial-state pool with per-batch indexing for GDN decode, plus an optimized BF16 fast path for pool-based state to improve inference performance; backward compatible (parameters are optional).
* **Bug Fixes**
  * Stronger validation and gating to ensure correct path selection and in-place state updates.
* **Documentation**
  * Clarified docstrings describing pool layout and index behavior.
* **Tests**
  * Added tests covering the pool+indices pretranspose decode path and in-place update consistency.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
