feat: add pool+indices support to gated_delta_rule_decode_pretranspose (bf16 path) #2619
yzh119 merged 2 commits into flashinfer-ai:main
Conversation
…e (bf16 path)

Allows callers to pass initial_state=[pool, HV, V, K] + initial_state_indices=[B] directly, avoiding expensive gather/transpose/scatter around the kernel call. The kernel reads/writes the pool in-place via per-batch index lookup. Only supported via the bf16 fast path (gdn_decode_klast_bf16_state): bfloat16 state, T in 1..4, K=V=128. The float32 legacy path raises an error.

Changes:
- gdn_decode.py: add initial_state/initial_state_indices params to gated_delta_rule_decode_pretranspose; route the pool path through the bf16 kernel
- gdn_decode_bf16_state.py: add gH0_indices param to all 3 kernels and 5 launch functions; gated_delta_rule now creates identity indices when initial_state_indices is None and includes pool_size in the cache key
- tests: update test_decode_kernel_pretranspose_pool to use bfloat16 state

AI-assisted
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
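The pool+indices contract described above can be illustrated with a plain-Python sketch; `decode_step`, `pool`, and `deltas` are hypothetical stand-ins for the CUDA kernel and its bfloat16 `[pool, HV, V, K]` state tensor:

```python
# Illustrative model of the pool+indices semantics, using plain Python lists
# instead of CUDA tensors. Names here (decode_step, pool, deltas) are
# hypothetical; the real kernel does a fused read-modify-write of bf16 state.

def decode_step(pool, indices, deltas):
    """Update each batch entry's pool slot in place via per-batch index lookup."""
    for batch_idx, slot in enumerate(indices):
        # Stand-in for the GDN state update: read slot, modify, write back.
        pool[slot] = pool[slot] + deltas[batch_idx]
    return pool  # same object: no gather before, no scatter after

pool = [10.0, 20.0, 30.0, 40.0]   # 4 pool slots; batch of 2 uses slots 3 and 1
indices = [3, 1]
decode_step(pool, indices, [1.0, 2.0])
print(pool)  # slots 3 and 1 updated in place; slots 0 and 2 untouched
```

This is the "Callsite After" shape from the PR description: the caller hands over the whole pool plus indices, and no scatter step is needed afterwards.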
📝 Walkthrough

Adds optional pool-based initial-state support to GDN decode kernels and Python wrappers, with per-batch index lookup.

Changes
Sequence Diagram(s)

sequenceDiagram
participant API as Python API
participant Dispatch as Validation / Dispatch
participant Pool as Initial-State Pool
participant Legacy as Legacy State
participant Kernel as GPU Kernel
participant Caller as Caller (returns)
API->>Dispatch: call gated_delta_rule_decode(..., initial_state?, initial_state_indices?)
Dispatch->>Dispatch: validate shapes, dtype, K/V constraints
alt Pool path (initial_state provided, bf16, K=V=128, T in 1..4)
Dispatch->>Pool: read initial_state & indices
Dispatch->>Kernel: launch bf16 pool-accelerated kernel with indices
Kernel->>Pool: per-batch pool_batch_idx = indices[batch]
Kernel->>Pool: update pool slots in-place
Kernel-->>Caller: return initial_state (pool)
else Legacy path
Dispatch->>Legacy: use provided state tensor (or in-place path)
Dispatch->>Kernel: launch legacy float32 kernel (no indices)
Kernel->>Legacy: update state in-place
Kernel-->>Caller: return state (legacy)
end
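The validate/dispatch step in the diagram can be sketched as a small pure-Python function; `select_path` and its return values are hypothetical names, but the gating conditions (bfloat16 state, K=V=128, T in 1..4) come from the PR description:

```python
# Hypothetical sketch of the path selection shown in the sequence diagram.
# The real checks live inside gated_delta_rule_decode_pretranspose.

def select_path(initial_state_provided, state_dtype, K, V, T):
    """Return which kernel path the dispatcher would take."""
    pool_ok = (
        initial_state_provided
        and state_dtype == "bfloat16"
        and K == 128 and V == 128
        and 1 <= T <= 4
    )
    if initial_state_provided and not pool_ok:
        # Pool path requested but constraints not met (e.g. float32 state).
        raise ValueError("pool path requires bfloat16 state, K=V=128, T in 1..4")
    return "bf16-pool" if pool_ok else "legacy-f32"

print(select_path(True, "bfloat16", 128, 128, 1))   # → bf16-pool
print(select_path(False, "float32", 128, 128, 1))   # → legacy-f32
```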
cc @hlu1
Code Review
The pull request successfully introduces pool+indices support to the gated_delta_rule_decode_pretranspose function, specifically targeting the bfloat16 fast path. This enhancement allows for more efficient state management by enabling in-place updates within a state pool, thereby eliminating the need for external gather/scatter operations. The implementation correctly routes the new parameters through the backend kernels and includes comprehensive tests to verify correctness and in-place behavior. My feedback focuses on ensuring device consistency for the newly introduced indices and optimizing the allocation of identity indices in the hot path.
    if initial_state_indices is None:
        h0_indices = torch.arange(B, dtype=torch.int32, device=q.device)
    elif initial_state_indices.dtype != torch.int32:
        h0_indices = initial_state_indices.to(torch.int32)
    else:
        h0_indices = initial_state_indices
There are two points for improvement here:

1. **Device consistency**: when `initial_state_indices` is provided by the user, it should be explicitly moved to the same device as the input tensors (e.g., `q.device`). If the user passes indices residing on the CPU, the current implementation passes a CPU pointer to the CUDA kernel via `from_dlpack`, which causes a runtime crash during kernel execution.
2. **Performance**: for the direct path (where `initial_state_indices` is `None`), `torch.arange` is called on every invocation. While fast for small batch sizes, it still introduces a small allocation overhead in the hot path. Since the kernel is already specialized for batch size `B` in the compilation cache, it would be more efficient to cache these identity indices or reuse a pre-allocated buffer.
    if initial_state_indices is None:
        h0_indices = torch.arange(B, dtype=torch.int32, device=q.device)
    else:
        h0_indices = initial_state_indices.to(device=q.device, dtype=torch.int32)
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/gdn/test_decode_delta_rule.py (1)

511-535: Pool+indices test only exercises T=1; T=2/3/4 not covered

The feature is described as supporting T in 1..4, but `_test_decode_kernel_pretranspose_pool` hardcodes `q = torch.randn(batch_size, 1, ...)` with no `seq_len` parameter. The two distinct kernel paths for T=1 (low-BS 1-chunk vs. standard) are exercised, but `gated_delta_rule_decode_kernel_seqlen234_unified` with the new `gH0_indices` argument is never exercised by a pool+indices test.

Consider adding a `seq_len: int` parameter (defaulting to 1) and parametrizing:

    +@pytest.mark.parametrize("seq_len", [1, 2, 3, 4])
     @pytest.mark.parametrize("scale", [1.0])
     ...
     def test_decode_kernel_pretranspose_pool(
         ...
    +    seq_len: int,
         seed: int = int(os.environ.get("SEED", "0")),
     ):
         _test_decode_kernel_pretranspose_pool(
    -        dtype, batch_size, ..., seed=seed,
    +        dtype, batch_size, ..., seq_len=seq_len, seed=seed,
         )
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (1)
2011-2016: Per-call `torch.arange(B)` allocation each decode step when `initial_state_indices=None`

The identity-indices tensor is recreated on every invocation. The analogous nontranspose path in `gdn_decode.py` (line 1934) avoids this by caching the tensor in the compiled-kernel cache dict. At high decode throughput this becomes repeated allocation/deallocation overhead.

♻️ Suggested approach (mirroring the nontranspose pattern)

    +    if "h0_indices" not in _compiled_kernels.get(cache_key, {}):
    +        _id_indices = torch.arange(B, dtype=torch.int32, device=q.device)
    +    else:
    +        _id_indices = _compiled_kernels[cache_key].get("h0_indices")
    +
         # Resolve indices: identity mapping when not provided
         if initial_state_indices is None:
    -        h0_indices = torch.arange(B, dtype=torch.int32, device=q.device)
    +        h0_indices = _id_indices
         elif initial_state_indices.dtype != torch.int32:
             h0_indices = initial_state_indices.to(torch.int32)
         else:
             h0_indices = initial_state_indices

Alternatively, cache the identity tensor in `_compiled_kernels[cache_key]` after the first compile (similar to how `cu_seqlens` is cached in the pretranspose path).
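The caching idea above can be sketched in plain Python, with a dict standing in for the compiled-kernel cache and a list standing in for the int32 indices tensor; `_cache` and `identity_indices` are hypothetical names:

```python
# Hypothetical sketch of caching identity indices per cache key, so the
# equivalent of torch.arange(B) is allocated once rather than every decode step.

_cache = {}

def identity_indices(B, cache_key):
    """Return cached identity indices for cache_key, creating them once per key."""
    entry = _cache.setdefault(cache_key, {})
    if "h0_indices" not in entry:
        # In the real code: torch.arange(B, dtype=torch.int32, device=q.device)
        entry["h0_indices"] = list(range(B))
    return entry["h0_indices"]

a = identity_indices(4, ("bf16", 4))
b = identity_indices(4, ("bf16", 4))
print(a is b)  # → True: allocated once, reused on subsequent decode steps
```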
    if use_pool:
        pool_size = initial_state.shape[0]
        assert initial_state.shape == (pool_size, HV, V, K), (
            f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
            f"got {initial_state.shape}"
        )
No bounds check on initial_state_indices values — out-of-bounds indices cause silent GPU memory corruption
initial_state.shape is validated, but individual index values in initial_state_indices are never checked against [0, pool_size). An out-of-range index produces a CUDA illegal-access fault or, worse, silently overwrites an adjacent allocation without any Python-level diagnostic.
🛡️ Proposed fix
     if use_pool:
         pool_size = initial_state.shape[0]
         assert initial_state.shape == (pool_size, HV, V, K), (
             f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
             f"got {initial_state.shape}"
         )
    +    assert initial_state_indices.shape == (B,), (
    +        f"Expected initial_state_indices shape [{B}], got {initial_state_indices.shape}"
    +    )
    +    assert (
    +        int(initial_state_indices.min()) >= 0
    +        and int(initial_state_indices.max()) < pool_size
    +    ), (
    +        f"initial_state_indices values must be in [0, {pool_size}), "
    +        f"got [{int(initial_state_indices.min())}, {int(initial_state_indices.max())}]"
    +    )
+ )Note: .min()/.max() trigger a host sync; if that's unacceptable on a hot path, guard this behind assert statements that can be compiled-out, or document the constraint clearly in the docstring.
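The compile-out-able guard mentioned in the note can be sketched in plain Python; `validate_indices` is a hypothetical helper operating on a list rather than a CUDA tensor. Running the interpreter with `-O` strips `assert` statements, so the check (and any host sync it would trigger on real tensors) disappears in optimized runs:

```python
# Minimal sketch of an assert-based bounds check that can be compiled out with
# `python -O`. validate_indices is hypothetical; real code would check an
# int32 tensor, where .min()/.max() force a device-to-host sync.

def validate_indices(indices, pool_size):
    """Raise AssertionError if any slot index falls outside [0, pool_size)."""
    assert all(0 <= i < pool_size for i in indices), (
        f"indices must be in [0, {pool_size}), got {indices}"
    )

validate_indices([3, 1, 0], 4)        # in range: passes silently
try:
    validate_indices([3, 7], 4)       # 7 is out of range
except AssertionError as e:
    print("caught:", e)
```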
    def _test_decode_kernel_pretranspose_pool(
        dtype: str,
        batch_size: int,
        num_q_heads: int,
        num_k_heads: int,
        num_v_heads: int,
        head_size: int,
        scale: float,
        pool_multiplier: int = 3,
        state_dtype: str = "bfloat16",
        seed: int | None = None,
    ):
        """Pool+indices path must match gather → direct-state → scatter reference."""
        _skip_if_not_sm90_or_later()
Missing GDN_DECODE_KLAST_BF16_STATE_AVAILABLE guard — test will hard-fail instead of skip
Every test that exercises the bf16-state kernel path (_test_gdn_decode_klast_bf16_state_kernel, test_pretranspose_api_uses_gdn_decode_klast_bf16_state) checks this flag before proceeding. This new helper omits it, so on any machine where gdn_decode_klast_bf16_state cannot be imported, gated_delta_rule_decode_pretranspose hits assert not use_pool in the legacy path and raises AssertionError instead of a clean pytest.skip.
🛡️ Proposed fix

     def _test_decode_kernel_pretranspose_pool(
         ...
     ):
         """Pool+indices path must match gather → direct-state → scatter reference."""
         _skip_if_not_sm90_or_later()
    +    if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE:
    +        pytest.skip("gdn_decode_klast_bf16_state kernel not available")
Sounds good, just want to make sure the interface and semantics are aligned.

TBH, the current GDN inference APIs are a bit cluttered. We're seeing a lot of forked paths to support different layouts (V-last vs. K-last), state dtypes (bf16 vs. fp32), and batch sizes, as well as varying MTP support. In my opinion, we should plan a refactor after this PR to converge these paths and stabilize the API.

gentle ping. thx @yzh119

Also, please let us know the plan for this. If there isn't one yet, we'd be happy to contribute.

Hi @kaixih
I don't think it's planned, and think it would be great if you can work on this.

@yzh119 any updates?

…yncs

Replace the ambiguous H0 subscript notation (which implies "H at time zero") with slot_indices to clearly convey pool slot selection. Also drop two defensive torch.cuda.synchronize() calls in the pool test that are unnecessary: the two kernel calls operate on separate tensors, and CUDA stream serialization already ensures ordering.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@yzh119 PTAL
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/gdn/test_decode_delta_rule.py (1)
406-420: ⚠️ Potential issue | 🟠 Major

Add a bf16-kernel availability skip guard in this helper.

At lines 406-420, this helper can still hard-fail when `gdn_decode_klast_bf16_state` is unavailable. Mirror the existing `GDN_DECODE_KLAST_BF16_STATE_AVAILABLE` skip pattern used elsewhere in this file before running the pool-path call.

🔧 Suggested fix

     def _test_decode_kernel_pretranspose_pool(
     @@
     ):
         """Pool+indices path must match gather → direct-state → scatter reference."""
    +    if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE:
    +        pytest.skip("gdn_decode_klast_bf16_state kernel not available")
         _skip_if_not_sm90_or_later()
    pool_size = initial_state_source.shape[0]

    if scale is None:
        scale = 1.0 / math.sqrt(K)

    # Resolve indices: identity mapping when not provided
    if initial_state_indices is None:
        h_slot_indices = torch.arange(B, dtype=torch.int32, device=q.device)
    elif initial_state_indices.dtype != torch.int32:
        h_slot_indices = initial_state_indices.to(torch.int32)
    else:
        h_slot_indices = initial_state_indices
Validate initial_state_indices before launching kernels to prevent OOB/racy state writes.
At Line 2010-Line 2016, only dtype normalization is done. Without shape/device/range (and duplicate-slot) checks, kernel reads/writes via gH_slot_indices[batch_idx] (Line 734/Line 1145/Line 1578) can go out of bounds or have concurrent writes to the same pool slot.
🔧 Suggested fix

         pool_size = initial_state_source.shape[0]
         # Resolve indices: identity mapping when not provided
         if initial_state_indices is None:
    +        if pool_size < B:
    +            raise ValueError(
    +                f"initial_state_source.shape[0] ({pool_size}) must be >= batch size ({B}) "
    +                "when initial_state_indices is None"
    +            )
             h_slot_indices = torch.arange(B, dtype=torch.int32, device=q.device)
    -    elif initial_state_indices.dtype != torch.int32:
    -        h_slot_indices = initial_state_indices.to(torch.int32)
         else:
    -        h_slot_indices = initial_state_indices
    +        if initial_state_indices.ndim != 1 or initial_state_indices.numel() != B:
    +            raise ValueError(
    +                f"initial_state_indices must have shape [B]={B}, got {tuple(initial_state_indices.shape)}"
    +            )
    +        h_slot_indices = initial_state_indices.to(
    +            device=q.device, dtype=torch.int32
    +        ).contiguous()
    +        if torch.any(h_slot_indices < 0) or torch.any(h_slot_indices >= pool_size):
    +            raise ValueError(
    +                f"initial_state_indices must be in [0, {pool_size}), got min/max="
    +                f"({int(h_slot_indices.min())}, {int(h_slot_indices.max())})"
    +            )
    +        if torch.unique(h_slot_indices).numel() != B:
    +            raise ValueError(
    +                "initial_state_indices must be unique per batch entry to avoid concurrent "
    +                "writes to the same pool slot"
    +            )
+ )Also applies to: 2028-2028
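The shape, bounds, and duplicate-slot checks proposed above can be modeled in plain Python; `check_slot_indices` is a hypothetical helper using lists in place of the int32 tensor:

```python
# Pure-Python illustration of the validation suggested above. check_slot_indices
# is hypothetical; real code would run equivalent checks on a CUDA int32 tensor.

def check_slot_indices(indices, B, pool_size):
    """Reject indices that could cause OOB reads or concurrent slot writes."""
    if len(indices) != B:
        raise ValueError(f"indices must have shape [B]={B}, got [{len(indices)}]")
    if any(i < 0 or i >= pool_size for i in indices):
        raise ValueError(f"indices must be in [0, {pool_size}), got {indices}")
    if len(set(indices)) != B:
        raise ValueError("indices must be unique to avoid concurrent writes to one slot")

check_slot_indices([2, 0, 5], B=3, pool_size=8)      # passes
try:
    check_slot_indices([2, 2, 5], B=3, pool_size=8)  # duplicate slot 2
except ValueError as e:
    print("rejected:", e)
```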
/bot run

[FAILED] Pipeline #45293540: 10/20 passed
Is Hopper supported? I'm encountering a problem. |
…om PR flashinfer-ai#2619

Resolve merge conflicts with upstream main, which added pool+indices support via the bf16 fast path (PR flashinfer-ai#2619). Key changes:
- Adopt upstream API naming: state_indices -> initial_state_indices; state pool passed via the initial_state param
- Update test_decode_pooled.py to use the new API with bf16 state
- Skip negative-index tests (the bf16 kernel does not support them yet)
- Legacy f32 CuTe DSL path preserved for non-pool usage

AI-assisted merge resolution.
…e (bf16 path) (flashinfer-ai#2619)

## Summary
- Adds `initial_state=[pool, HV, V, K]` and `initial_state_indices=[B]` parameters to `gated_delta_rule_decode_pretranspose`, allowing callers to pass an SSM state pool + per-batch indices directly instead of doing gather/scatter around the kernel.
- The kernel reads and writes the pool in-place, so no separate scatter step is needed after the call.
- Only supported via the bf16 fast path (`gdn_decode_klast_bf16_state`): bfloat16 state, T in 1..4, K=V=128.

```python
# Callsite Before:
state = ssm_states[cache_indices]          # gather
out, state = kernel(state=state, ...)
ssm_states[cache_indices] = state          # scatter

# Callsite After:
out, _ = kernel(initial_state=ssm_states, initial_state_indices=cache_indices, ...)
# ssm_states updated in-place, no scatter needed
```

## Changes
- gdn_decode.py: add initial_state/initial_state_indices params; route pool path through the bf16 kernel
- gdn_decode_bf16_state.py: add gH0_indices to all 3 kernels and 5 launch functions; gated_delta_rule now uses identity indices when initial_state_indices=None
- tests/gdn/test_decode_delta_rule.py: add pool+indices test against gather→direct-state reference

## Summary by CodeRabbit
- **New Features**
  - Optional external initial-state pool with per-batch indexing for GDN decode, plus an optimized BF16 fast path for pool-based state to improve inference performance; backward compatible (parameters are optional).
- **Bug Fixes**
  - Stronger validation and gating to ensure correct path selection and in-place state updates.
- **Documentation**
  - Clarified docstrings describing pool layout and index behavior.
- **Tests**
  - Added tests covering the pool+indices pretranspose decode path and in-place update consistency.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>