Fix OOB crash in intermediate_states indexing for GDN decode MTP kernel #3145
wenscarl wants to merge 3 commits into flashinfer-ai:main from
Conversation
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration: defaults | Review profile: CHILL | Plan: Pro
📥 Commits: reviewing files that changed from the base of the PR and between d4b90120cef035e49f10efd182f0e9ec44fd6db0 and ec69283.
📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 Walkthrough
Adjusts the MTP BF16 decode intermediate-state layout and write indexing to use batch-addressed slices.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ Passed checks (5 passed)
Code Review
This pull request fixes a bug in the gdn_decode_bf16state_mtp_kernel where intermediate states were incorrectly indexed using the pool slot index (cache_idx) instead of the batch index (i_n). To prevent regressions, the test suite has been updated to include scenarios where the pool size is larger than the batch size, specifically by introducing a pool_size_multiplier. Feedback was provided to optimize the test code by removing unnecessary .cpu() and .clone() calls when indexing GPU tensors.
```diff
- ref_state = input_state_ref_bf16.clone()
+ # Reference: step through tokens with bf16 state.
+ # Select only the batch entries' initial states from the pool.
+ ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
```
The call to .cpu() on initial_state_indices is unnecessary because both the indices and the tensor being indexed (input_state_ref_bf16) are already on the GPU. Moving indices to the CPU just to index a GPU tensor is inefficient as it may trigger unnecessary host-device synchronization. Additionally, indexing with a tensor in PyTorch always creates a copy, so the .clone() call is redundant.
```diff
- ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
+ ref_state = input_state_ref_bf16[initial_state_indices]
```
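As a side note, here is a minimal PyTorch sketch (the tensor names are hypothetical; assumes a CUDA device) showing that advanced indexing already returns a copy, which is why the trailing `.clone()` adds nothing:

```python
import torch

pool = torch.ones(8, 4, device="cuda", dtype=torch.bfloat16)  # stand-in state pool
idx = torch.tensor([5, 6, 7], device="cuda")                  # CUDA indices; no .cpu() needed

picked = pool[idx]   # advanced indexing gathers into a NEW tensor, not a view
picked.zero_()       # mutating the gathered copy...
assert pool.sum().item() == 32.0  # ...leaves the pool untouched (8 * 4 ones)
```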
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (2)
2568-2576: ⚠️ Potential issue | 🟠 Major — Validate the dimensions required by the new flat index.
Line 1880 assumes a `[B, T, HV, V, K]` layout flattened with `T * HV` stride. Today `buffer_size < B` can still OOB, and `cache_steps > T` is accepted even though the kernel will write with the wrong batch stride.
🛡️ Proposed validation fix
```diff
  buffer_size = intermediate_states_buffer.shape[0]
  cache_steps = intermediate_states_buffer.shape[1]
- assert cache_steps >= T, (
-     f"intermediate_states_buffer dim 1 ({cache_steps}) must be >= T={T}"
+ assert buffer_size >= B, (
+     f"intermediate_states_buffer dim 0 ({buffer_size}) must be >= B={B}"
+ )
+ assert cache_steps == T, (
+     f"intermediate_states_buffer dim 1 ({cache_steps}) must equal T={T}"
  )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2568 - 2576, The code flattens intermediate_states_buffer assuming a [B, T, HV, V, K] layout and a batch stride of T*HV, but it doesn't validate that buffer_size and cache_steps match the expected B and T (allowing OOB when buffer_size < B or silent misuse when cache_steps != T). Add explicit validations before reshaping: assert intermediate_states_buffer.dim() == 5, assert cache_steps == T, assert buffer_size == B (or equivalently buffer_size * cache_steps == B * T if B is known), and assert intermediate_states_buffer.shape[2:5] == (HV, V, K); keep the dtype check for torch.bfloat16 and only then perform the reshape into intermediate_states to ensure safe indexing in the kernel.
1097-1099: ⚠️ Potential issue | 🟡 Minor — Update stale intermediate-state shape docs.
The implementation now treats `intermediate_states` as batch-scoped, but these comments still advertise pool-scoped storage. That can mislead callers into allocating/reading the wrong shape.
📝 Proposed doc fix
```diff
- intermediate_states: cute.Tensor,  # [pool_size * T * HV, V, K] as BF16 (or dummy)
+ intermediate_states: cute.Tensor,  # [B * T * HV, V, K] as BF16 (or dummy)

- intermediate_states: cute.Tensor,  # [pool_size * T * HV, V, K] BF16 (or dummy)
+ intermediate_states: cute.Tensor,  # [B * T * HV, V, K] BF16 (or dummy)

- intermediate_states_buffer: Optional [pool_size, T, HV, V, K] bf16
+ intermediate_states_buffer: Optional [B, T, HV, V, K] bf16
```
Also applies to: 2024-2028, 2523-2528
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 1097 - 1099, The doc comment for the intermediate_states parameter is stale: update the shape description to reflect that intermediate_states is batch-scoped (not pool-scoped). Replace the current "[pool_size * T * HV, V, K] as BF16 (or dummy)" wording with a batch-scoped shape like "[batch_size * T * HV, V, K] as BF16 (or dummy)" in the parameter docs for intermediate_states in gdn_decode_bf16_state (and make the identical change in the other occurrences around the sections noted), so callers allocate/read using batch_size rather than pool_size.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1815-1816: Run the project's pre-commit formatters
(black/ruff/pre-commit) on the changed test blocks to remove trailing whitespace
and apply ruff-format rewrites; specifically reformat the new assignment and
surrounding code that uses pool_size, batch_size, and pool_size_multiplier and
the other affected blocks referenced around the same test (the blocks near the
pool_size assignment and the later test sections), ensuring no trailing spaces
remain and ruff/black rules are satisfied.
---
Outside diff comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 2568-2576: The code flattens intermediate_states_buffer assuming a
[B, T, HV, V, K] layout and a batch stride of T*HV, but it doesn't validate that
buffer_size and cache_steps match the expected B and T (allowing OOB when
buffer_size < B or silent misuse when cache_steps != T). Add explicit
validations before reshaping: assert intermediate_states_buffer.dim() == 5,
assert cache_steps == T, assert buffer_size == B (or equivalently buffer_size *
cache_steps == B * T if B is known), and assert
intermediate_states_buffer.shape[2:5] == (HV, V, K); keep the dtype check for
torch.bfloat16 and only then perform the reshape into intermediate_states to
ensure safe indexing in the kernel.
- Around line 1097-1099: The doc comment for the intermediate_states parameter
is stale: update the shape description to reflect that intermediate_states is
batch-scoped (not pool-scoped). Replace the current "[pool_size * T * HV, V, K]
as BF16 (or dummy)" wording with a batch-scoped shape like "[batch_size * T *
HV, V, K] as BF16 (or dummy)" in the parameter docs for intermediate_states in
gdn_decode_bf16_state (and make the identical change in the other occurrences
around the sections noted), so callers allocate/read using batch_size rather
than pool_size.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c9765bd7-41f8-4853-982a-08bc53382b31
📥 Commits
Reviewing files that changed from the base of the PR and between fb3bb44 and bcdcdc4a78d2c223326ba4bc58d4735796f31cc1.
📒 Files selected for processing (2)
flashinfer/gdn_kernels/gdn_decode_bf16_state.py
tests/gdn/test_decode_delta_rule.py
/bot run
kahyunnam left a comment
Please fix the pre-commit failure here: https://github.com/flashinfer-ai/flashinfer/actions/runs/24785336107/job/72527611010?pr=3145
🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)
1883: Minor — the `.cpu()` round-trip is unnecessary for GPU tensor indexing.
`input_state_ref_bf16` lives on CUDA, and `initial_state_indices` is a CUDA int32 tensor — PyTorch can index directly without moving indices to CPU (the gather then forces a D2H sync). Not a correctness issue (dtype int32 is accepted here), just a small cleanup.
♻️ Proposed simplification
```diff
- ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
+ ref_state = input_state_ref_bf16[initial_state_indices.long()].clone()
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_delta_rule.py` at line 1883, The code does an unnecessary device round-trip by calling .cpu() on initial_state_indices when indexing a CUDA tensor; update the indexing of input_state_ref_bf16 by removing the .cpu() call so ref_state = input_state_ref_bf16[initial_state_indices].clone() (i.e., locate the expression constructing ref_state and drop the .cpu() to allow direct CUDA-to-CUDA indexing with initial_state_indices).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Line 1883: The code does an unnecessary device round-trip by calling .cpu() on
initial_state_indices when indexing a CUDA tensor; update the indexing of
input_state_ref_bf16 by removing the .cpu() call so ref_state =
input_state_ref_bf16[initial_state_indices].clone() (i.e., locate the expression
constructing ref_state and drop the .cpu() to allow direct CUDA-to-CUDA indexing
with initial_state_indices).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8945e231-c4fc-4839-add7-028b5835aede
📥 Commits
Reviewing files that changed from the base of the PR and between bcdcdc4a78d2c223326ba4bc58d4735796f31cc1 and d4b90120cef035e49f10efd182f0e9ec44fd6db0.
📒 Files selected for processing (1)
tests/gdn/test_decode_delta_rule.py
ameynaik-hub left a comment
Thanks for the fix.
The code fix looks correct to me. One small cleanup before merge: a few comments/docstrings still describe intermediate_states as pool-scoped, but this PR makes it batch-scoped.
Fix comment/docstring descriptions of `intermediate_states` to reflect the batch-scoped shape [B * T * HV, V, K] / [B, T, HV, V, K] instead of the outdated pool-scoped shape, as noted in PR review by ameynaik-hub. AI-assisted Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed from 249eaeb to ec69283 (Compare)
… PR flashinfer-ai#3145)

The ``intermediate_states_buffer`` is BATCH-scoped — shape ``[B, T, HV, V, K]`` — but both the wide_vec kernel (line 1115) and the mtp_ilp4 kernel (line 621) were indexing it with ``cache_idx * T * HV + i_t * HV + i_hv``, where ``cache_idx`` is the POOL slot from ``initial_state_indices[i_n]``. When ``pool_size > B`` (every realistic serving config) and ``initial_state_indices`` points at slots ``>= B`` (e.g. the middle of a 1024-slot pool while servicing a B=32 batch), ``cache_idx * T * HV`` exceeds the buffer's ``B * T * HV`` extent and the ``cute.local_tile`` write goes off the end of the cache buffer -> ``cudaErrorIllegalAddress`` or silent memory corruption.

This is the same bug upstream PR flashinfer-ai#3145 fixed in the now-removed ``gdn_decode_bf16state_mtp_kernel``; both surviving BF16 kernels inherited the incorrect pattern.

Fix:
- ``gdn_decode_bf16state_mtp_ilp4_kernel``: ``flat_idx = i_n * T * HV + ...`` (was ``cache_idx * T * HV + ...``).
- ``gdn_wide_vec_kernel``: same.
- Dispatcher (both ``gated_delta_rule_mtp`` and ``gated_delta_rule_mtp_wide_vec``): assert ``intermediate_states_buffer.shape[0] == B`` and reshape using ``B`` rather than ``buffer_size``.

Also updates the comment / docstring to call out batch-scoped semantics explicitly.

Adds ``test_gdn_decode_bf16_state_mtp_pool_larger_than_batch`` (12 cases), which parametrizes ``pool_size_multiplier in {1, 4}``, ``batch_size in {1, 8, 32}``, and ``seq_len in {2, 4}`` so both the ilp4 path (B=1) and the wide_vec path (B=8/32) are exercised with pool indices pointing at the upper half of a 4*B-slot pool.

Verified the test catches the bug: re-introducing the ``cache_idx * T * HV`` form makes the test fail with ``cudaErrorIllegalAddress``; reverting the line makes it pass again.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
/bot run
<!-- .github/pull_request_template.md -->

## Summary

Replaces the legacy `gdn_decode_bf16state_cooprow_kernel` and the `gdn_decode_bf16state_mtp_kernel` (ILP=8) with a new **`gdn_wide_vec_kernel`** (LDG.E.128 / STG.E.128 fast path) plus a small-batch `mtp_ilp4` fallback. Drops ~1900 LOC of dead/unused code, adds split-pool support (#2905-compatible) to both surviving BF16 kernels, and ships the OOB fix mirroring upstream PR #3145 — for the BF16 kernels that survived the cleanup.

**Supersedes #3118.** That PR's perf delta (T=1 per-call overhead + pool+padding for the ILP kernel) is the first commit on this branch (`8a6e9819`).

## What changes

- **New kernel**: `gdn_wide_vec_kernel` — 128 threads/CTA = 8 groups × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per thread. Configurable `tile_v ∈ {32, 64, 128}` so the kernel covers small/medium/large `B*HV` work-unit sizes uniformly.
- **Pool-only**: BF16 GDN dispatch is strictly pool-mode (matches the production serving contract). Wrapper `gated_delta_rule_decode_pretranspose` auto-promotes legacy non-pool callers internally — public API unchanged.
- **Split-pool support** (PR #2905 contract): both surviving BF16 kernels (`gdn_wide_vec_kernel`, `gdn_decode_bf16state_mtp_ilp4_kernel`) natively support `output_state_indices != initial_state_indices`, with bit-equivalent single-pool behavior selected at compile time via `Constexpr[bool] same_pool` for zero-overhead dispatch.
- **OOB fix (PR #3145 equivalent)**: `intermediate_states` is indexed by the per-call batch index `i_n` (not the pool-scoped `cache_idx`), so the buffer can be sized `[B, T, HV, V, K]` as production callers expect. A regression test catches the bug; pre-fix it triggers `cudaErrorIllegalAddress` in <2 s.

## Removed (~1900 LOC of dead code)

| Kernel | Why removed |
|---|---|
| `gdn_decode_bf16state_cooprow_kernel` (~280 LOC) | Replaced by wide_vec + ILP=4 MTP; had known correctness issues at small batch |
| `gdn_decode_bf16state_ilp_kernel` (~740 LOC) | Only reachable at HV<32 with B≥16 — not a Qwen3.5 shape; MTP path covers it |
| `gdn_decode_bf16state_mtp_kernel` (ILP=8) (~940 LOC) | After the wide_vec extension to split-pool + tile_v=32, mtp_kernel was unreachable |

End-state BF16 surface = **2 `@cute.kernel`s in one file**:

- `gdn_wide_vec_kernel` — production hot path
- `gdn_decode_bf16state_mtp_ilp4_kernel` — small-batch fallback

Both pool-only, both split-pool capable, both indexed batch-scoped.

## Speedup vs previous baseline

Baseline = pre-wide_vec dispatch (the `mtp_kernel` ILP=8 path, captured on this same branch by monkey-patching `_select_wide_vec_tile_v` to return `None` for every shape). Same harness, same hardware, same config — so the comparison isolates the kernel-level speedup that wide_vec + the cleanup deliver.

Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50, T=1 invoked with `--update-state`, T≥2 invoked with `--cache-intermediate-states`. Kernel time in microseconds (CUPTI).
### Speedup (×, baseline / post-PR)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|-------|-------|-------|-------|-------|-------|-------|
| 1 | 1.03× | 1.04× | 1.03× | 1.03× | 1.00× | 1.02× | 1.02× | 1.01× |
| 4 | 0.97× | 1.23× | 1.10× | 1.12× | 1.11× | 1.11× | 1.14× | 1.14× |
| 8 | 1.08× | 1.11× | 1.11× | 1.12× | 1.13× | 1.15× | 1.15× | 1.14× |
| 16 | 1.04× | 1.09× | 1.11× | 1.13× | 1.13× | 1.12× | 1.11× | 1.10× |
| 32 | 1.06× | 1.12× | 1.10× | 1.11× | 1.10× | 1.09× | 1.09× | 1.09× |
| 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.06× | 1.07× | 1.08× | 1.06× |
| 128 | 1.04× | 1.11× | 1.07× | 1.09× | 1.06× | 1.06× | 1.07× | 1.07× |
| 256 | 1.04× | 1.11× | 1.07× | 1.09× | 1.07× | 1.06× | 1.07× | 1.07× |

### Time reduction (%)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|--------|--------|--------|--------|--------|--------|--------|
| 1 | +2.8% | +3.5% | +3.2% | +3.3% | +0.1% | +1.7% | +1.8% | +0.7% |
| 4 | −3.2% | +18.7% | +9.5% | +10.7% | +9.7% | +10.2% | +12.3% | +12.5% |
| 8 | +7.8% | +10.1% | +10.0% | +10.7% | +11.8% | +12.7% | +12.8% | +12.4% |
| 16 | +3.4% | +8.5% | +10.0% | +11.2% | +11.4% | +10.7% | +10.3% | +9.4% |
| 32 | +6.0% | +10.4% | +9.4% | +9.6% | +8.8% | +8.4% | +8.4% | +7.8% |
| 64 | +4.0% | +10.3% | +7.5% | +8.6% | +5.9% | +6.3% | +7.2% | +5.9% |
| 128 | +4.2% | +9.8% | +6.5% | +8.4% | +6.0% | +5.9% | +6.3% | +6.6% |
| 256 | +4.2% | +9.6% | +6.6% | +8.4% | +6.3% | +5.9% | +6.5% | +6.6% |

### Headline

- **T=1 production decode (B≥16)**: 4–6 % time reduction across the full batch sweep — the Qwen3.5 hot path.
- **T≥2 with cache=ON (B≥4)**: 6–18 % time reduction at every shape. Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×).
- **Tiny shapes (B=1)**: within ±3 % of baseline (kernel isn't DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was already efficient there).

### Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|------|------|------|------|------|------|------|------|
| 1 | 1.25 | 1.21 | 1.49 | 1.63 | 1.45 | 1.55 | 1.64 | 1.70 |
| 4 | 2.83 | 3.24 | 3.54 | 3.78 | 3.53 | 3.68 | 3.84 | 3.95 |
| 8 | 3.97 | 4.09 | 4.40 | 4.54 | 4.42 | 4.55 | 4.59 | 4.61 |
| 16 | 4.73 | 4.73 | 5.02 | 5.03 | 4.95 | 4.91 | 4.92 | 4.87 |
| 32 | 5.39 | 5.36 | 5.44 | 5.46 | 5.27 | 5.23 | 5.21 | 5.17 |
| 64 | 5.83 | 5.76 | 5.80 | 5.77 | 5.45 | 5.44 | 5.45 | 5.33 |
| 128 | 6.31 | 6.05 | 6.03 | 6.01 | 5.68 | 5.61 | 5.57 | 5.54 |
| 256 | 6.57 | 6.23 | 6.20 | 6.17 | 5.85 | 5.74 | 5.72 | 5.66 |

Post-PR peaks at **6.57 TB/s = 82 % of B200 peak DRAM** (T=1 B=256 production decode shape).

### Split-pool

With wide_vec now supporting split-pool natively, split-pool matches single-pool to within ±1 % at every measured shape.

## Tests

> **513 passed, 0 failed in 18m18s** on B200.

Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP tests, 12 new OOB regression tests covering `pool_size_multiplier ∈ {1, 4}` × `B ∈ {1, 8, 32}` × `T ∈ {2, 4}`, and 12 wrapper-level split-pool tests.

## Files changed (4)

- `flashinfer/gdn_decode.py` — wrapper auto-promotes BF16 non-pool → pool
- `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` — wide_vec inlined; dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
- `tests/gdn/test_decode_delta_rule.py` — split-pool + OOB regression tests
- `benchmarks/bench_gdn_decode.py` — `--pool-mode {single,split}` flag

<!-- What does this PR do? Briefly describe the changes and why they're needed.
-->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added `--pool-mode` option to the benchmark tool for configuring state pool allocation (`single` or `split` modes).
* **Tests**
  * Expanded BF16 test coverage with regression tests for split-pool semantics and out-of-bounds scenarios; improved batch-dimension handling for intermediate-state comparisons.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
co-authored by @YAMY1234
Problem
`gdn_decode_bf16state_mtp_kernel` crashes with an out-of-bounds GPU memory write when `intermediate_states_buffer` is provided and `pool_size > B` with `initial_state_indices` pointing to upper pool slots (the normal serving scenario).
Affected path: `flashinfer/gdn_kernels/gdn_decode_bf16_state.py:1873`

```python
# Before (wrong)
flat_idx = cache_idx * T * HV + i_t * HV + i_hv
# After (correct)
flat_idx = i_n * T * HV + i_t * HV + i_hv
```
Root Cause
When `intermediate_states` support was added to the BF16 state kernel, the author reused the `cache_idx`-based addressing pattern from the persistent state pool access:

```python
flat_state_idx = cache_idx * HV + i_hv  # correct: h0_source is pool-scoped
```
...and extended it by analogy to add a `T` dimension for `intermediate_states`. This is wrong because the two buffers have different ownership semantics:

- `h0_source` (the state pool) is pool-scoped — it persists across decode steps, with one slot per concurrent request in the system → correctly indexed by `cache_idx` (pool slot).
- `intermediate_states` is batch-scoped — a per-forward-pass output capturing states at each of the T steps → should be indexed by `i_n` (batch position).

The float32 counterpart `gdn_decode_mtp.py` has always used `i_n` correctly. The BF16 kernel diverged when it was written.
The bug was invisible in existing tests because they always set `pool_size = batch_size` with `initial_state_indices = arange(B)`, making `cache_idx == i_n` identically. The buggy and correct indexing produce the same result in that configuration. The docstring describing the buffer shape as `[pool_size, T, HV, V, K]` further reinforced the incorrect mental model.

The crash only manifests in the realistic serving scenario where `pool_size >> B` and `initial_state_indices` contains values ≥ B, causing `cache_idx`-based writes to go beyond the end of a batch-sized buffer.
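A minimal pure-Python sketch of that index arithmetic (the sizes here are hypothetical, chosen only to make the overrun concrete):

```python
# Hypothetical serving shape: batch-sized buffer, much larger state pool.
B, T, HV = 32, 4, 64
pool_size = 1024
buffer_rows = B * T * HV  # the intermediate_states buffer has B*T*HV rows

# Batch entries parked in the upper pool slots (values >= B).
initial_state_indices = list(range(pool_size - B, pool_size))  # 992..1023

i_n = 0                                 # batch position of the first entry
cache_idx = initial_state_indices[i_n]  # pool slot 992
i_t = i_hv = 0

buggy_flat_idx = cache_idx * T * HV + i_t * HV + i_hv  # 253952
fixed_flat_idx = i_n * T * HV + i_t * HV + i_hv        # 0

assert fixed_flat_idx < buffer_rows   # in bounds
assert buggy_flat_idx >= buffer_rows  # writes far past the buffer end
```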
Fix

Change `cache_idx` to `i_n` at the `intermediate_states` indexing site, and update the docstring to reflect the correct buffer shape `[B, T, HV, V, K]` instead of `[pool_size, T, HV, V, K]`.

Test Changes
The existing `test_gdn_decode_bf16_state_mtp_kernel` always used `pool_size = batch_size`, which masked the bug entirely. Two changes are made (see the sketch after this list):

- A `pool_size_multiplier` parameter is added to the helper `_test_gdn_decode_bf16_state_mtp_kernel`. When `> 1`, it sets `pool_size = batch_size * pool_size_multiplier` and assigns each batch entry to an upper pool slot (`initial_state_indices = arange(B) + pool_size - B`), so `cache_idx >= B` for every entry. The `intermediate_states_buffer` is allocated with `batch_size` as its first dimension — the semantically correct size — which is smaller than `pool_size`. With the buggy `cache_idx` indexing the kernel writes out of bounds and crashes; after the fix it produces results matching the reference.
- `@pytest.mark.parametrize("pool_size_multiplier", [1, 4])` is added to the public test, doubling the existing matrix with a realistic pool-vs-batch configuration. The `pool_size_multiplier=4, cache_intermediate_states=True` cases are the ones that directly catch this bug.
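For concreteness, a hedged sketch of that setup; the names mirror the description above, but the helper's exact signature and dims in the repo may differ:

```python
import torch

batch_size, pool_size_multiplier = 8, 4
T, HV, V, K = 2, 64, 128, 128  # illustrative dims

pool_size = batch_size * pool_size_multiplier  # 32 pool slots for an 8-entry batch

# Each batch entry maps to an upper pool slot, so cache_idx >= batch_size everywhere.
initial_state_indices = (
    torch.arange(batch_size, dtype=torch.int32, device="cuda")
    + pool_size
    - batch_size
)  # -> [24, 25, ..., 31]

# Batch-scoped buffer: first dim is batch_size, NOT pool_size.
intermediate_states_buffer = torch.empty(
    batch_size, T, HV, V, K, dtype=torch.bfloat16, device="cuda"
)

# With the buggy cache_idx indexing, writes land at row offsets up to
# (pool_size - 1) * T * HV, well past a batch_size * T * HV buffer.
```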