
Fix OOB crash in intermediate_states indexing for GDN decode MTP kernel#3145

Open
wenscarl wants to merge 3 commits into flashinfer-ai:main from wenscarl:shuw/gdn_mtp_fix

Conversation


@wenscarl wenscarl commented Apr 22, 2026

co-authored by @YAMY1234

Problem

gdn_decode_bf16state_mtp_kernel crashes with an out-of-bounds GPU memory write when
intermediate_states_buffer is provided and pool_size > B with initial_state_indices
pointing to upper pool slots (the normal serving scenario).

Affected path: flashinfer/gdn_kernels/gdn_decode_bf16_state.py:1873

```python
# Before (wrong): indexes by pool slot
flat_idx = cache_idx * T * HV + i_t * HV + i_hv

# After (correct): indexes by batch position
flat_idx = i_n * T * HV + i_t * HV + i_hv
```

Root Cause

When intermediate_states support was added to the BF16 state kernel, the author reused
the cache_idx-based addressing pattern from the persistent state pool access:

```python
flat_state_idx = cache_idx * HV + i_hv # correct: h0_source is pool-scoped
```

...and extended it by analogy to add a T dimension for intermediate_states. This is
wrong because the two buffers have different ownership semantics:

  • h0_source (the state pool) is pool-scoped — persists across decode steps, one slot
    per concurrent request in the system → correctly indexed by cache_idx (pool slot)
  • intermediate_states is batch-scoped — a per-forward-pass output capturing states at
    each of the T steps → should be indexed by i_n (batch position), as the sketch
    below illustrates
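
To make the divergence concrete, here is a minimal NumPy sketch (hypothetical sizes; the actual kernel is cute-DSL code, not NumPy) contrasting the two flat-index computations once the batch is mapped to upper pool slots:

```python
import numpy as np

# Hypothetical small sizes; in serving, pool_size >> B.
B, T, HV = 2, 3, 4
pool_size = 16

# Batch entries mapped to upper pool slots, so cache_idx >= B everywhere.
initial_state_indices = np.arange(B) + pool_size - B

for i_n in range(B):                        # batch position
    cache_idx = initial_state_indices[i_n]  # pool slot
    for i_t in range(T):
        for i_hv in range(HV):
            correct = i_n * T * HV + i_t * HV + i_hv      # always < B * T * HV
            buggy = cache_idx * T * HV + i_t * HV + i_hv  # >= B * T * HV here
            assert correct < B * T * HV
            assert buggy >= B * T * HV  # would write past a batch-sized buffer

# With pool_size == B and initial_state_indices == arange(B) — the old test
# configuration — cache_idx == i_n, so the two formulas coincide and the bug
# stays hidden.
```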

The float32 counterpart gdn_decode_mtp.py has always used i_n correctly. The BF16
kernel diverged when it was written.

The bug was invisible in existing tests because they always set pool_size = batch_size
with initial_state_indices = arange(B), making cache_idx == i_n identically. The
buggy and correct indexing produce the same result in that configuration. The docstring
describing the buffer shape as [pool_size, T, HV, V, K] further reinforced the incorrect
mental model.

The crash only manifests in the realistic serving scenario where pool_size >> B and
initial_state_indices contains values ≥ B, causing cache_idx-based writes to go
beyond the end of a batch-sized buffer.

Fix

Change cache_idx to i_n at the intermediate_states indexing site, and update the
docstring to reflect the correct buffer shape [B, T, HV, V, K] instead of
[pool_size, T, HV, V, K].
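
For callers, a hedged allocation sketch under the corrected contract (the shapes are hypothetical serving values, consistent with those quoted elsewhere on this page):

```python
import torch

B, T, HV, V, K = 32, 4, 64, 128, 128  # hypothetical serving shapes

# Batch-scoped buffer per the corrected docstring: [B, T, HV, V, K].
intermediate_states_buffer = torch.empty(
    B, T, HV, V, K, dtype=torch.bfloat16, device="cuda"
)
# Sizing the leading dim by pool_size instead (as the old docstring implied)
# would allocate pool_size / B times more memory than is ever read back.
```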

Test Changes

The existing test_gdn_decode_bf16_state_mtp_kernel always used pool_size = batch_size,
which masked the bug entirely. Two changes are made:

  1. A pool_size_multiplier parameter is added to the helper
    _test_gdn_decode_bf16_state_mtp_kernel. When > 1, it sets
    pool_size = batch_size * pool_size_multiplier and assigns each batch entry to an upper
    pool slot (initial_state_indices = arange(B) + pool_size - B), so cache_idx >= B
    for every entry. The intermediate_states_buffer is allocated with batch_size as its
    first dimension — the semantically correct size — which is smaller than pool_size.
    With the buggy cache_idx indexing the kernel writes out of bounds and crashes; after
    the fix it produces results matching the reference (see the sketch after this list).

  2. @pytest.mark.parametrize("pool_size_multiplier", [1, 4]) added to the public
    test, doubling the existing matrix with a realistic pool-vs-batch configuration. The
    pool_size_multiplier=4, cache_intermediate_states=True cases are the ones that
    directly catch this bug.
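
A hedged sketch of the index-remapping portion of this setup (the real test also launches the kernel and checks its output against the reference; everything beyond the names quoted above is hypothetical):

```python
import pytest
import torch

@pytest.mark.parametrize("pool_size_multiplier", [1, 4])
def test_upper_pool_slot_mapping(pool_size_multiplier):
    batch_size = 8  # hypothetical
    pool_size = batch_size * pool_size_multiplier
    initial_state_indices = (
        torch.arange(batch_size, dtype=torch.int32) + pool_size - batch_size
    )
    if pool_size_multiplier > 1:
        # Every cache_idx lands in the upper pool region (>= batch_size),
        # the configuration that exposed the OOB write.
        assert bool((initial_state_indices >= batch_size).all())
    else:
        # Legacy configuration: cache_idx == i_n, which masked the bug.
        assert torch.equal(
            initial_state_indices,
            torch.arange(batch_size, dtype=torch.int32),
        )
```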

Summary by CodeRabbit

  • Bug Fixes
    • Fixed intermediate-state indexing so cached states are written and read per batch slice consistently.
  • Tests
    • Expanded tests to cover pool sizes larger than the batch (parameterized), including remapped indices and updated assertions for those scenarios.
  • Documentation
    • Updated API/docs and launcher expectations to reflect the new intermediate-states buffer shape.


coderabbitai Bot commented Apr 22, 2026

No actionable comments were generated in the recent review. 🎉


📝 Walkthrough

Adjusts the MTP BF16 decode intermediate-state layout and write indexing to use batch-addressed slices (B * T * HV) instead of pool-addressed slots; tests are extended to cover pool sizes larger than the batch to validate the new indexing.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Kernel layout & indexing — `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` | Change documented intermediate_states logical layout from pool-addressed to batch-addressed; when cache_intermediate_states is enabled, write indexing in gdn_decode_bf16state_mtp_kernel uses the launch batch index (i_n) for flat_idx instead of the pooled cache_idx. |
| Launcher / API docs — `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` | Update public launcher API signatures and docs to expect intermediate_states_buffer shaped [B, T, HV, V, K]. |
| Tests — pool-size coverage — `tests/gdn/test_decode_delta_rule.py` | Parametrize MTP BF16-state test with pool_size_multiplier (e.g., 1, 4); enlarge pool when multiplier > 1, remap initial_state_indices into the upper pool region, allocate intermediate_states_buffer with leading dim [batch_size], and update reference selection and assertion messages to include pool_multiplier context. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • bkryu
  • kahyunnam
  • yongwww

Poem

🐰 I hopped through memory, counted every lane,

Swapped a pooled address for an i_n name,
I nudged the tests to make pools grow wide,
So each token's state now finds its proper side,
A tiny hop, but indices aligned.

🚥 Pre-merge checks — ✅ 5 passed

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the main fix: correcting an out-of-bounds crash caused by incorrect intermediate_states indexing in the GDN decode MTP kernel. |
| Description check | ✅ Passed | The description is comprehensive, addressing problem statement, root cause analysis, the specific fix, and test changes. However, it does not follow the provided template structure with sections like '📌 Description', '🔍 Related Issues', and '🚀 Pull Request Checklist'. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.00%, which meets the required 80.00% threshold. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |



@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request fixes a bug in the gdn_decode_bf16state_mtp_kernel where intermediate states were incorrectly indexed using the pool slot index (cache_idx) instead of the batch index (i_n). To prevent regressions, the test suite has been updated to include scenarios where the pool size is larger than the batch size, specifically by introducing a pool_size_multiplier. Feedback was provided to optimize the test code by removing unnecessary .cpu() and .clone() calls when indexing GPU tensors.

The comment was left on this hunk:

```python
ref_state = input_state_ref_bf16.clone()
# Reference: step through tokens with bf16 state.
# Select only the batch entries' initial states from the pool.
ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
```
Severity: medium

The call to .cpu() on initial_state_indices is unnecessary because both the indices and the tensor being indexed (input_state_ref_bf16) are already on the GPU. Moving indices to the CPU just to index a GPU tensor is inefficient as it may trigger unnecessary host-device synchronization. Additionally, indexing with a tensor in PyTorch always creates a copy, so the .clone() call is redundant.

Suggested change

```diff
-ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
+ref_state = input_state_ref_bf16[initial_state_indices]
```


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gdn_kernels/gdn_decode_bf16_state.py (2)

2568-2576: ⚠️ Potential issue | 🟠 Major

Validate the dimensions required by the new flat index.

Line 1880 assumes a [B, T, HV, V, K] layout flattened with T * HV stride. Today buffer_size < B can still OOB, and cache_steps > T is accepted even though the kernel will write with the wrong batch stride.

🛡️ Proposed validation fix
```diff
         buffer_size = intermediate_states_buffer.shape[0]
         cache_steps = intermediate_states_buffer.shape[1]
-        assert cache_steps >= T, (
-            f"intermediate_states_buffer dim 1 ({cache_steps}) must be >= T={T}"
+        assert buffer_size >= B, (
+            f"intermediate_states_buffer dim 0 ({buffer_size}) must be >= B={B}"
+        )
+        assert cache_steps == T, (
+            f"intermediate_states_buffer dim 1 ({cache_steps}) must equal T={T}"
         )
```

1097-1099: ⚠️ Potential issue | 🟡 Minor

Update stale intermediate-state shape docs.

The implementation now treats intermediate_states as batch-scoped, but these comments still advertise pool-scoped storage. That can mislead callers into allocating/reading the wrong shape.

📝 Proposed doc fix
```diff
-    intermediate_states: cute.Tensor,  # [pool_size * T * HV, V, K] as BF16 (or dummy)
+    intermediate_states: cute.Tensor,  # [B * T * HV, V, K] as BF16 (or dummy)
-    intermediate_states: cute.Tensor,  # [pool_size * T * HV, V, K] BF16 (or dummy)
+    intermediate_states: cute.Tensor,  # [B * T * HV, V, K] BF16 (or dummy)
-        intermediate_states_buffer: Optional [pool_size, T, HV, V, K] bf16
+        intermediate_states_buffer: Optional [B, T, HV, V, K] bf16
```

Also applies to: 2024-2028, 2523-2528

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gdn/test_decode_delta_rule.py`:
- Around line 1815-1816: Run the project's pre-commit formatters
(black/ruff/pre-commit) on the changed test blocks to remove trailing whitespace
and apply ruff-format rewrites; specifically reformat the new assignment and
surrounding code that uses pool_size, batch_size, and pool_size_multiplier and
the other affected blocks referenced around the same test (the blocks near the
pool_size assignment and the later test sections), ensuring no trailing spaces
remain and ruff/black rules are satisfied.



Comment thread tests/gdn/test_decode_delta_rule.py Outdated
@kahyunnam
Member

/bot run

@flashinfer-bot
Collaborator

GitLab MR !590 has been created, and the CI pipeline #49302059 is currently running. I'll report back once the pipeline job completes.


@kahyunnam kahyunnam left a comment



@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/gdn/test_decode_delta_rule.py (1)

1883-1883: Minor: .cpu() round-trip is unnecessary for GPU tensor indexing.

input_state_ref_bf16 lives on CUDA, and initial_state_indices is a CUDA int32 tensor — PyTorch can index directly without moving indices to CPU (the gather then forces a D2H sync). Not a correctness issue (dtype int32 is accepted here), just a small cleanup.

♻️ Proposed simplification

```diff
-    ref_state = input_state_ref_bf16[initial_state_indices.cpu()].clone()
+    ref_state = input_state_ref_bf16[initial_state_indices.long()].clone()
```


@wenscarl wenscarl requested a review from kahyunnam April 23, 2026 20:42
Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated

@ameynaik-hub ameynaik-hub left a comment


Thanks for the fix.
The code fix looks correct to me. One small cleanup before merge: a few comments/docstrings still describe intermediate_states as pool-scoped, but this PR makes it batch-scoped.

Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated
Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated
@wenscarl wenscarl requested a review from ameynaik-hub April 27, 2026 15:09
wenscarl and others added 3 commits April 27, 2026 10:09
Fix comment/docstring descriptions of `intermediate_states` to reflect
the batch-scoped shape [B * T * HV, V, K] / [B, T, HV, V, K] instead
of the outdated pool-scoped shape, as noted in PR review by ameynaik-hub.

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Apr 28, 2026
… PR flashinfer-ai#3145)

The ``intermediate_states_buffer`` is BATCH-scoped — shape ``[B, T, HV, V, K]``
— but both the wide_vec kernel (line 1115) and the mtp_ilp4 kernel
(line 621) were indexing it with ``cache_idx * T * HV + i_t * HV + i_hv``
where ``cache_idx`` is the POOL slot from ``initial_state_indices[i_n]``.

When ``pool_size > B`` (every realistic serving config) and
``initial_state_indices`` points at slots ``>= B`` (e.g. middle of a
1024-slot pool while servicing a B=32 batch), ``cache_idx * T * HV``
exceeds the buffer's ``B * T * HV`` extent and the
``cute.local_tile`` write goes off the end of the cache buffer ->
``cudaErrorIllegalAddress`` or silent memory corruption.

This is the same bug upstream PR flashinfer-ai#3145 fixed in the now-removed
``gdn_decode_bf16state_mtp_kernel``; both surviving BF16 kernels
inherited the incorrect pattern.

Fix:
- ``gdn_decode_bf16state_mtp_ilp4_kernel``: ``flat_idx = i_n * T * HV + ...``
  (was ``cache_idx * T * HV + ...``).
- ``gdn_wide_vec_kernel``: same.
- Dispatcher (both ``gated_delta_rule_mtp`` and
  ``gated_delta_rule_mtp_wide_vec``): assert
  ``intermediate_states_buffer.shape[0] == B`` and reshape using ``B``
  rather than ``buffer_size``. Also updates the comment / docstring to
  call out batch-scoped semantics explicitly.
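
A hedged sketch of the dispatcher-side checks described in the last bullet (the wrapper function itself is hypothetical; only the shape/dtype contract comes from this commit message):

```python
import torch

def _check_intermediate_states(buf: torch.Tensor, B: int, T: int,
                               HV: int, V: int, K: int) -> torch.Tensor:
    # Batch-scoped contract: [B, T, HV, V, K], bf16.
    assert buf.dim() == 5, "expected a 5-D [B, T, HV, V, K] buffer"
    assert buf.shape[0] == B, f"dim 0 ({buf.shape[0]}) must equal B={B}"
    assert buf.shape[1] == T, f"dim 1 ({buf.shape[1]}) must equal T={T}"
    assert tuple(buf.shape[2:]) == (HV, V, K)
    assert buf.dtype == torch.bfloat16
    # Flatten using B (batch-scoped), never the pool size.
    return buf.reshape(B * T * HV, V, K)
```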

Adds ``test_gdn_decode_bf16_state_mtp_pool_larger_than_batch`` (12
cases) which parametrizes ``pool_size_multiplier in {1, 4}`` and
``batch_size in {1, 8, 32}`` and ``seq_len in {2, 4}`` so both the
ilp4 path (B=1) and the wide_vec path (B=8/32) are exercised with
pool indices pointing at the upper half of a 4*B-slot pool. Verified
the test catches the bug: re-introducing the
``cache_idx * T * HV`` form makes the test fail with
``cudaErrorIllegalAddress``; reverting the line makes it pass again.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub mentioned this pull request Apr 28, 2026
@nvpohanh
Contributor

nvpohanh commented May 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !590 has been updated with latest changes, and the CI pipeline #50199162 is currently running. I'll report back once the pipeline job completes.

kahyunnam pushed a commit that referenced this pull request May 6, 2026

  ## Summary

  Replaces the legacy `gdn_decode_bf16state_cooprow_kernel` and the
  `gdn_decode_bf16state_mtp_kernel` (ILP=8) with a new
  **`gdn_wide_vec_kernel`** (LDG.E.128 / STG.E.128 fast path) plus a
  small-batch `mtp_ilp4` fallback. Drops ~1900 LOC of dead/unused code,
  adds split-pool support (#2905-compatible) to both surviving BF16
  kernels, and ships the OOB fix mirroring upstream PR #3145 — for the
  BF16 kernels that survived the cleanup.

  **Supersedes #3118.** That PR's perf delta (T=1 per-call overhead +
  pool+padding for the ILP kernel) is the first commit on this branch
  (`8a6e9819`). 
  
  ## What changes

  - **New kernel**: `gdn_wide_vec_kernel` — 128 threads/CTA = 8 groups
    × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per
    thread. Configurable `tile_v ∈ {32, 64, 128}` so the kernel covers
    small/medium/large `B*HV` work-unit sizes uniformly.
  - **Pool-only**: BF16 GDN dispatch is strictly pool-mode (matches
    the production serving contract). Wrapper
    `gated_delta_rule_decode_pretranspose` auto-promotes legacy non-pool
    callers internally — public API unchanged.
  - **Split-pool support** (PR #2905 contract): both surviving BF16
kernels (`gdn_wide_vec_kernel`, `gdn_decode_bf16state_mtp_ilp4_kernel`)
    natively support `output_state_indices != initial_state_indices`,
    with bit-equivalent single-pool behavior selected at compile time
    via `Constexpr[bool] same_pool` for zero-overhead dispatch.
  - **OOB fix (PR #3145 equivalent)**: `intermediate_states` is indexed
    by the per-call batch index `i_n` (not the pool-scoped `cache_idx`),
    so the buffer can be sized `[B, T, HV, V, K]` as production callers
    expect. Regression test catches the bug; pre-fix triggers
    `cudaErrorIllegalAddress` in <2 s.
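
A one-line sanity check on the vector width named in the first bullet (assuming the standard 16 bits per BF16 element):

```python
BF16_BITS = 16
VEC = 8  # BF16 elements per vectorized access
assert VEC * BF16_BITS == 128  # exactly one LDG.E.128 / STG.E.128 transaction
```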

  ## Removed (~1900 LOC of dead code)

  | Kernel | Why removed |
  |---|---|
  | `gdn_decode_bf16state_cooprow_kernel` (~280 LOC) | Replaced by wide_vec + ILP=4 MTP; had known correctness issues at small batch |
  | `gdn_decode_bf16state_ilp_kernel` (~740 LOC) | Only reachable at HV<32 with B≥16 — not a Qwen3.5 shape; MTP path covers it |
  | `gdn_decode_bf16state_mtp_kernel` (ILP=8) (~940 LOC) | After wide_vec extension to split-pool + tile_v=32, mtp_kernel was unreachable |

  End-state BF16 surface = **2 `@cute.kernel`s in one file**:
  - `gdn_wide_vec_kernel` — production hot path
  - `gdn_decode_bf16state_mtp_ilp4_kernel` — small-batch fallback

  Both pool-only, both split-pool capable, both indexed batch-scoped.

  ## Speedup vs previous baseline

Baseline = pre-wide_vec dispatch (the `mtp_kernel` ILP=8 path, captured
  on this same branch by monkey-patching `_select_wide_vec_tile_v` to
  return `None` for every shape). Same harness, same hardware, same
  config — so the comparison isolates the kernel-level speedup that
  wide_vec + the cleanup deliver.

  Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50,
  T=1 invoked with `--update-state`, T≥2 invoked with
  `--cache-intermediate-states`. Kernel time in microseconds (CUPTI).

  ### Speedup (×, baseline / post-PR)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|-------|-------|-------|-------|-------|-------|-------|
| 1 | 1.03× | 1.04× | 1.03× | 1.03× | 1.00× | 1.02× | 1.02× | 1.01× |
| 4 | 0.97× | 1.23× | 1.10× | 1.12× | 1.11× | 1.11× | 1.14× | 1.14× |
| 8 | 1.08× | 1.11× | 1.11× | 1.12× | 1.13× | 1.15× | 1.15× | 1.14× |
| 16 | 1.04× | 1.09× | 1.11× | 1.13× | 1.13× | 1.12× | 1.11× | 1.10× |
| 32 | 1.06× | 1.12× | 1.10× | 1.11× | 1.10× | 1.09× | 1.09× | 1.09× |
| 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.06× | 1.07× | 1.08× | 1.06× |
| 128 | 1.04× | 1.11× | 1.07× | 1.09× | 1.06× | 1.06× | 1.07× | 1.07× |
| 256 | 1.04× | 1.11× | 1.07× | 1.09× | 1.07× | 1.06× | 1.07× | 1.07× |

  ### Time reduction (%)

| B | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|-----|-------|--------|-------|--------|--------|--------|--------|--------|
| 1 | +2.8% | +3.5% | +3.2% | +3.3% | +0.1% | +1.7% | +1.8% | +0.7% |
| 4 | −3.2% | +18.7% | +9.5% | +10.7% | +9.7% | +10.2% | +12.3% | +12.5% |
| 8 | +7.8% | +10.1% | +10.0% | +10.7% | +11.8% | +12.7% | +12.8% | +12.4% |
| 16 | +3.4% | +8.5% | +10.0% | +11.2% | +11.4% | +10.7% | +10.3% | +9.4% |
| 32 | +6.0% | +10.4% | +9.4% | +9.6% | +8.8% | +8.4% | +8.4% | +7.8% |
| 64 | +4.0% | +10.3% | +7.5% | +8.6% | +5.9% | +6.3% | +7.2% | +5.9% |
| 128 | +4.2% | +9.8% | +6.5% | +8.4% | +6.0% | +5.9% | +6.3% | +6.6% |
| 256 | +4.2% | +9.6% | +6.6% | +8.4% | +6.3% | +5.9% | +6.5% | +6.6% |

  ### Headline

- **T=1 production decode (B≥16)**: 4–6 % time reduction across the full
batch sweep — the Qwen3.5 hot path.
- **T≥2 with cache=ON (B≥4)**: 6–18 % time reduction at every shape.
Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×).
- **Tiny shapes (B=1)**: within ±3 % of baseline (kernel isn't
DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was
already efficient there).

  ### Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)

  |  B  | T=1  | T=2  | T=3  | T=4  | T=5  | T=6  | T=7  | T=8  |
  |-----|------|------|------|------|------|------|------|------|
  |   1 | 1.25 | 1.21 | 1.49 | 1.63 | 1.45 | 1.55 | 1.64 | 1.70 |
  |   4 | 2.83 | 3.24 | 3.54 | 3.78 | 3.53 | 3.68 | 3.84 | 3.95 |
  |   8 | 3.97 | 4.09 | 4.40 | 4.54 | 4.42 | 4.55 | 4.59 | 4.61 |
  |  16 | 4.73 | 4.73 | 5.02 | 5.03 | 4.95 | 4.91 | 4.92 | 4.87 |
  |  32 | 5.39 | 5.36 | 5.44 | 5.46 | 5.27 | 5.23 | 5.21 | 5.17 |
  |  64 | 5.83 | 5.76 | 5.80 | 5.77 | 5.45 | 5.44 | 5.45 | 5.33 |
  | 128 | 6.31 | 6.05 | 6.03 | 6.01 | 5.68 | 5.61 | 5.57 | 5.54 |
  | 256 | 6.57 | 6.23 | 6.20 | 6.17 | 5.85 | 5.74 | 5.72 | 5.66 |

Post-PR peaks at **6.57 TB/s = 82 % of B200 peak DRAM** (T=1 B=256
production decode shape).
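
The quoted utilization figure is easy to verify (assuming the 8 TB/s peak stated above):

```python
peak_tb_s, measured_tb_s = 8.0, 6.57  # B200 peak; T=1, B=256 measurement
print(f"{measured_tb_s / peak_tb_s:.0%}")  # -> 82%
```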

  ### Split-pool

  With wide_vec now supporting split-pool natively, split-pool
  matches single-pool to within ±1 % at every measured shape.

  ## Tests

  > **513 passed, 0 failed in 18m18s** on B200.

Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP
tests, 12 new OOB regression tests covering `pool_size_multiplier ∈ {1,
4}` × `B ∈ {1, 8, 32}` × `T ∈ {2, 4}`, 12 wrapper-level split-pool
tests.

  ## Files changed (4)

  - `flashinfer/gdn_decode.py` — wrapper auto-promotes BF16 non-pool → pool
  - `flashinfer/gdn_kernels/gdn_decode_bf16_state.py` — wide_vec inlined; dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
  - `tests/gdn/test_decode_delta_rule.py` — split-pool + OOB regression tests
  - `benchmarks/bench_gdn_decode.py` — `--pool-mode {single,split}` flag

## Summary by CodeRabbit

* **New Features**
* Added `--pool-mode` option to benchmark tool for configuring state
pool allocation (`single` or `split` modes).

* **Tests**
* Expanded BF16 test coverage with regression tests for split-pool
semantics and out-of-bounds scenarios; improved batch-dimension handling
for intermediate-state comparisons.

---------

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>