
Ameyn/wide vec t1 #3147

Merged
kahyunnam merged 31 commits into flashinfer-ai:main from
ameynaik-hub:ameyn/wide_vec_t1
May 6, 2026

Conversation

@ameynaik-hub
Contributor

@ameynaik-hub ameynaik-hub commented Apr 22, 2026

Summary

Replaces the legacy gdn_decode_bf16state_cooprow_kernel and the
gdn_decode_bf16state_mtp_kernel (ILP=8) with a new
gdn_wide_vec_kernel (LDG.E.128 / STG.E.128 fast path) plus a
small-batch mtp_ilp4 fallback. Drops ~1900 LOC of dead/unused code,
adds split-pool support (#2905-compatible) to both surviving BF16
kernels, and ships the OOB fix mirroring upstream PR #3145.

Supersedes #3118; that PR's perf work (the T=1 per-call overhead fix
plus pool+padding support for the ILP kernel) lands as the first
commit on this branch (8a6e9819).

What changes

  • New kernel: gdn_wide_vec_kernel — 128 threads/CTA = 8 groups
    × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per
    thread. Configurable tile_v ∈ {32, 64, 128} so the kernel covers
    small/medium/large B*HV work-unit sizes uniformly.
  • Pool-only: BF16 GDN dispatch is strictly pool-mode (matches
    the production serving contract). Wrapper
    gated_delta_rule_decode_pretranspose auto-promotes legacy non-pool
    callers internally — public API unchanged.
  • Split-pool support (PR #2905 contract, "feat(gdn): separate input
    and output pool indices"): both surviving BF16 kernels
    (gdn_wide_vec_kernel, gdn_decode_bf16state_mtp_ilp4_kernel)
    natively support output_state_indices != initial_state_indices,
    with bit-equivalent single-pool behavior selected at compile time
    via Constexpr[bool] same_pool for zero-overhead dispatch (see the
    sketch after this list).
  • OOB fix (equivalent to PR #3145, "Fix OOB crash in
    intermediate_states indexing for GDN decode MTP kernel"):
    intermediate_states is indexed by the per-call batch index i_n
    (not the pool-scoped cache_idx), so the buffer can be sized
    [B, T, HV, V, K] as production callers expect. A regression test
    catches the bug; the pre-fix code triggers cudaErrorIllegalAddress
    in under 2 s.
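
For readers new to the #2905 contract, a minimal PyTorch sketch of the
split-pool semantics (tensor names follow this PR; the helper itself is
illustrative, not the kernel's actual code):

```python
import torch

def split_pool_update(pool: torch.Tensor,
                      initial_state_indices: torch.Tensor,
                      output_state_indices: torch.Tensor,
                      update_fn) -> None:
    # pool: [pool_size, HV, V, K]. Batch element i reads its recurrent
    # state from pool[initial_state_indices[i]] and commits the updated
    # state to pool[output_state_indices[i]]. Single-pool mode is the
    # special case where both index tensors are identical.
    states = pool[initial_state_indices]             # gather: [B, HV, V, K]
    pool[output_state_indices] = update_fn(states)   # scatter to write slots
```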

Removed (~1900 LOC of dead code)

  Kernel                                     Size       Why removed
  gdn_decode_bf16state_cooprow_kernel        ~280 LOC   Replaced by wide_vec + ILP=4 MTP; had known correctness issues at small batch
  gdn_decode_bf16state_ilp_kernel            ~740 LOC   Only reachable at HV<32 with B≥16 — not a Qwen3.5 shape; the MTP path covers it
  gdn_decode_bf16state_mtp_kernel (ILP=8)    ~940 LOC   Unreachable once wide_vec gained split-pool support and tile_v=32

End-state BF16 surface = 2 @cute.kernels in one file:

  • gdn_wide_vec_kernel — production hot path
  • gdn_decode_bf16state_mtp_ilp4_kernel — small-batch fallback

Both pool-only, both split-pool capable, both indexing
intermediate_states batch-scoped.

Speedup vs previous baseline

Baseline = pre-wide_vec dispatch (the mtp_kernel ILP=8 path, captured
on this same branch by monkey-patching _select_wide_vec_tile_v to
return None for every shape). Same harness, same hardware, same
config — so the comparison isolates the kernel-level speedup that
wide_vec + the cleanup deliver.
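
A sketch of how that baseline capture works (module path from the
files-changed list below; run_benchmark is a hypothetical stand-in for
the harness invocation):

```python
import flashinfer.gdn_kernels.gdn_decode_bf16_state as bf16_state

# Returning None from the tile_v picker makes the dispatcher decline
# wide_vec for every shape, forcing the pre-wide_vec fallback path.
_original_picker = bf16_state._select_wide_vec_tile_v
bf16_state._select_wide_vec_tile_v = lambda *args, **kwargs: None
try:
    run_benchmark()  # hypothetical stand-in for the CUPTI harness
finally:
    bf16_state._select_wide_vec_tile_v = _original_picker
```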

Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50,
T=1 invoked with --update-state, T≥2 invoked with
--cache-intermediate-states. Kernel time in microseconds (CUPTI).

Speedup (×, baseline / post-PR)

    B     T=1    T=2    T=3    T=4    T=5    T=6    T=7    T=8
    1    1.03×  1.04×  1.03×  1.03×  1.00×  1.02×  1.02×  1.01×
    4    0.97×  1.23×  1.10×  1.12×  1.11×  1.11×  1.14×  1.14×
    8    1.08×  1.11×  1.11×  1.12×  1.13×  1.15×  1.15×  1.14×
   16    1.04×  1.09×  1.11×  1.13×  1.13×  1.12×  1.11×  1.10×
   32    1.06×  1.12×  1.10×  1.11×  1.10×  1.09×  1.09×  1.09×
   64    1.04×  1.11×  1.08×  1.09×  1.06×  1.07×  1.08×  1.06×
  128    1.04×  1.11×  1.07×  1.09×  1.06×  1.06×  1.07×  1.07×
  256    1.04×  1.11×  1.07×  1.09×  1.07×  1.06×  1.07×  1.07×

Headline

  • T=1 production decode (B≥16): 4–6 % time reduction across the full batch sweep — the Qwen3.5 hot path.
  • T≥2 with cache=ON (B≥4): 6–18 % time reduction at every shape. Best at small-T / mid-batch (B=4 T=2: 1.23×; B=8 T=6: 1.15×).
  • Tiny shapes (B=1): within ±3 % of baseline (kernel isn't DRAM-bound; small fixed-cost overheads dominate; the ILP=4 fallback was already efficient there).

Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)

    B    T=1   T=2   T=3   T=4   T=5   T=6   T=7   T=8
    1   1.25  1.21  1.49  1.63  1.45  1.55  1.64  1.70
    4   2.83  3.24  3.54  3.78  3.53  3.68  3.84  3.95
    8   3.97  4.09  4.40  4.54  4.42  4.55  4.59  4.61
   16   4.73  4.73  5.02  5.03  4.95  4.91  4.92  4.87
   32   5.39  5.36  5.44  5.46  5.27  5.23  5.21  5.17
   64   5.83  5.76  5.80  5.77  5.45  5.44  5.45  5.33
  128   6.31  6.05  6.03  6.01  5.68  5.61  5.57  5.54
  256   6.57  6.23  6.20  6.17  5.85  5.74  5.72  5.66

Post-PR peaks at 6.57 TB/s = 82 % of B200 peak DRAM (T=1 B=256 production decode shape).

Split-pool

With wide_vec now supporting split-pool natively, split-pool
matches single-pool to within ±1 % at every measured shape.

Tests

513 passed, 0 failed in 18m18s on B200.

Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP tests, 12 new OOB regression tests covering pool_size_multiplier ∈ {1, 4} × B ∈ {1, 8, 32} × T ∈ {2, 4}, 12 wrapper-level split-pool tests.

Files changed (4)

  • flashinfer/gdn_decode.py — wrapper auto-promotes BF16 non-pool → pool
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py — wide_vec inlined; dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
  • tests/gdn/test_decode_delta_rule.py — split-pool + OOB regression tests
  • benchmarks/bench_gdn_decode.py — new --pool-mode {single,split} flag

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added --pool-mode option to benchmark tool for configuring state pool allocation (single or split modes).
  • Tests

    • Expanded BF16 test coverage with regression tests for split-pool semantics and out-of-bounds scenarios; improved batch-dimension handling for intermediate-state comparisons.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 22, 2026

📝 Walkthrough

Implements support for separate input and output state indices in BF16-accelerated gated-delta-rule decoding, enabling split-pool dispatch where outputs are written to different pool slots than inputs are read from. Unifies both T==1 and T>1 code paths to route through consistent pool/index handling, updates the benchmark wrapper with a pool_mode parameter, and adds comprehensive test coverage including split-pool semantics and edge-case validation.

Changes

BF16 Split-Pool State Dispatch

  • Core Implementation (flashinfer/gdn_decode.py): Unified pool/index
    handling for both T==1 and T>1 BF16 paths;
    gated_delta_rule_decode_pretranspose now routes non-pool state as a
    synthetic pool and forwards both initial_state_indices and
    output_state_indices into BF16 state kernel calls, enabling
    split-pool dispatch.
  • Benchmark Wrapper (benchmarks/bench_gdn_decode.py):
    gdn_decode_bf16_state_wrapper adds an output_state_indices argument;
    bench_gdn_decode_bf16_state adds a pool_mode parameter ("single" or
    "split"); the CLI --pool-mode flag propagates the pool mode from the
    command line through the benchmark function.
  • Test Coverage (tests/gdn/test_decode_delta_rule.py): State kernel
    tests now explicitly pass initial_state_indices=torch.arange(batch_size)
    for pool-only semantics; the BF16 T==1 test is expanded to include a
    num_v_heads=64 variant; the BF16 MTP intermediate-state comparison is
    chunked to prevent OOM; a new wide-vector BF16 MTP test uses
    monkeypatch to exercise varied tile_v parameters; two new BF16 MTP
    regression tests validate split-pool semantics and OOB index handling.

Estimated Code Review Effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly Related Issues

Possibly Related PRs

Suggested Labels

ready, run-ci

Suggested Reviewers

  • yzh119
  • bkryu
  • nvmbreughe
  • jimmyzho
  • yongwww
  • kahyunnam

Poem

🐰 Split the pool in two, we say,
Read from here, write far away,
Indices dance both high and low,
BF16 kernels steal the show!
Tests now catch the edge-case flaws,
Split-pool state obeys new laws!

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

  • Title check (⚠️ Warning): The PR title 'Ameyn/wide vec t1' is vague
    and uses non-descriptive terms that do not clearly convey the main
    purpose of this substantial refactoring. Resolution: use a clear,
    descriptive title such as 'Replace legacy BF16 GDN kernels with
    wide_vec path and add split-pool support' or 'Add gdn_wide_vec_kernel
    with split-pool support and OOB fix'.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 64.71%, below
    the required threshold of 80.00%. Resolution: write docstrings for
    the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)

  • Linked Issues check (✅ Passed): Check skipped because no linked
    issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Check skipped because no
    linked issues were found for this pull request.
  • Description check (✅ Passed): The PR description is comprehensive
    and detailed, covering all required sections with extensive technical
    documentation, performance data, and implementation details.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a wide-vector BF16 GDN MTP decode kernel and an ILP=4 variant to optimize performance across various batch sizes and work units. It also refactors the dispatch logic to support pool-based state management with indices and includes updated tests. Feedback from the review focuses on the significant code duplication across the new kernels and the repetitive nature of the memory write-back logic, suggesting that these be refactored into parameterized implementations or loops to improve maintainability.

Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated
Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py Outdated
Comment thread flashinfer/gdn_kernels/gdn_decode_bf16_state.py Outdated
ameynaik-hub and others added 22 commits April 27, 2026 09:33
…to ILP kernel

Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:

1) Per-call Python overhead fix
   Move torch.arange / torch.empty / from_dlpack / get_device_capability out
   of the steady-state call path. These were previously called on every
   invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
   of CUPTI-visible overhead per call at small BS. All default-tensor
   allocation and dlpack conversion is now done once, inside the
   `cache_key not in _compiled_kernels_*` block, and cached alongside the
   compiled kernel. Steady-state calls pass raw torch tensors directly to
   the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
   per-call torch.cuda.get_device_capability(). (The caching pattern is
   sketched below.)

2) Pool + padding support on the ILP kernel
   gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
   accepts h0_indices and h0_out_indices, matching the MTP kernel's
   signature. Negative indices redirect to pool slot 0 (null buffer);
   writes go to a separate flat_write_idx so input and output pool slots
   can differ. The ILP launcher and gated_delta_rule wrapper thread the
   new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
   collapsed so pool+indices T=1 calls no longer detour through the
   heavier MTP kernel.

Design choice: kernel always takes indices (no constexpr switch).
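
A minimal sketch of the compile-once pattern from (1), assuming a
generic cache dict (the real code keeps per-kernel _compiled_kernels_*
dicts and caches the default tensors and dlpack handles in the same
entry):

```python
_compiled_kernels: dict = {}

def _get_or_build(cache_key, build_fn):
    # All one-time work (default-tensor allocation, dlpack conversion,
    # device-capability query, kernel compilation) happens on a cache
    # miss; steady-state calls reduce to a single dict lookup.
    entry = _compiled_kernels.get(cache_key)
    if entry is None:
        entry = build_fn()
        _compiled_kernels[cache_key] = entry
    return entry
```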

Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200

Command:
    python benchmarks/bench_gdn_decode.py \
        --batch-size 1 4 8 16 32 64 128 256 512 \
        --num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
        --head-size 128 --dtype bfloat16 --warmup 20 --iters 200

Bf16State column results (us):

  BS | time
   1 |   3.71
   4 |   5.89
   8 |   9.18
  16 |  14.98
  32 |  26.56
  64 |  48.24
 128 |  89.66
 256 | 172.35
 512 | 337.60

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…stic fix

Three coordinated changes to the BF16 GDN MTP decode path on B200:

1. New kernel `gated_delta_rule_mtp_wide_vec`
   (flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py):
   - 128 threads/CTA organised as 8 groups of 16 threads (baseline: 4 of 32)
   - vec=8 BF16 per thread -> LDG.128 / STG.128 (baseline: vec=4 -> .64)
   - ILP=4 V-rows/thread, 4-stage butterfly over 16-lane subgroup
   - No TMA, no persistent CTAs; the win is purely wider LSU transactions
   - Peaks at 77.6 % DRAM SOL at B=256, T=2, HV=64

2. Auto-dispatcher inside `gated_delta_rule_mtp`
   (flashinfer/gdn_kernels/gdn_decode_bf16_state.py):
   - New module-level `_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024`
   - Routes to wide_vec when `B*HV >= 1024 and T >= 2 and K == V == 128`
   - T=1 and small-batch callers keep the existing baseline path verbatim,
     so the public API is source-and-ABI compatible.

3. B=4 HV=64 T=2 heuristic fix in `_get_bf16_mtp_config`:
   - Baseline's T=2 code path (gdn_decode_bf16_state.py) recomputes g/beta
     inline instead of reading from sGB (only T>2 pre-populates sGB). The
     inline softplus+exp+log+sigmoid stalls the ILP=8 pipeline at small
     work_units.
   - Fix: when `seq_len == 2 and work_units <= 256`, return (tile_v, 4)
     instead of (tile_v, 8). ILP=4 gives ~62 % occupancy vs ILP=8's ~37 %,
     covering the recompute latency.
   - Measured: B=4 HV=64 T=2: 11.20 us -> 9.63 us (1.17x).

Benchmark - Qwen3.5-397B-A17B Gated DeltaNet shape
(B200, HV=64, H_Q=H_K=16, K=V=128, BF16, cache_intermediate=True,
disable_state_update=True; T=1 uses state-update ON, no intermediate caching;
measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state with
CUPTI, 5 warmup + 50 bench iters per cell):

Wall-time (us), dispatcher's best-of-both output
(= baseline for B<=8, wide_vec for B>=16 at this HV; T=1 always baseline):

    B \ T       T=1      T=2      T=3      T=4      T=5      T=6      T=7      T=8
    -----    -------  -------  -------  -------  -------  -------  -------  -------
     1         3.46     5.44     5.79     6.67     8.83     9.79    10.61    11.50
     2         4.25     6.40     7.14     8.27    10.66    11.97    13.07    14.18
     4         5.79     9.63    10.53    12.51    16.03    18.11    20.32    22.22
     8         8.96    13.63    17.02    21.14    26.11    30.02    34.16    37.92
    16        15.22    21.20    27.23    33.73    41.23    47.89    55.15    62.11
    32        26.37    37.81    49.98    62.93    78.46    91.20   103.73   117.58
    64        47.76    70.69    93.76   117.92   146.42   172.80   197.82   225.44
   128        90.38   135.17   180.46   226.99   278.78   329.41   378.18   432.17
   256       173.65   262.98   351.06   440.75   542.37   641.63   739.60   846.40
   512       337.98   516.24   691.10   869.02    ERROR    ERROR    ERROR    ERROR

Wall-time (us), pre-session baseline (forced by monkey-patching
_WIDE_VEC_WORK_UNITS_THRESHOLD to 10**9):

    B \ T       T=1      T=2      T=3      T=4      T=5      T=6      T=7      T=8
    -----    -------  -------  -------  -------  -------  -------  -------  -------
     1         3.46     5.44     5.79     6.67     8.83     9.79    10.61    11.50
     2         4.25     6.40     7.14     8.27    10.66    11.97    13.07    14.18
     4         5.79     9.63    10.53    12.51    16.03    18.11    20.32    22.22
     8         8.96    13.63    17.02    21.14    26.11    30.02    34.16    37.92
    16        15.22    23.71    30.02    37.84    46.82    54.50    61.95    69.89
    32        26.37    42.56    54.75    69.01    85.41    99.74   114.13   128.50
    64        47.76    79.18   101.38   129.02   159.44   187.44   216.08   244.03
   128        90.38   149.34   193.09   247.38   305.87   361.33   417.20   473.81
   256       173.65   289.12   375.31   481.41   596.72   705.98   816.64   930.16
   512       337.98   568.32   741.25   952.97    ERROR    ERROR    ERROR    ERROR

Speedup (baseline / dispatcher). <=1.00x means dispatcher keeps baseline.
B=4 T=2 "1.00x" reflects the *new* heuristic (ilp=4) already in effect for
both columns; the pre-heuristic-fix baseline at that cell was 11.20 us
(1.16x slower than the post-fix 9.63 us).

    B \ T      T=1    T=2    T=3    T=4    T=5    T=6    T=7    T=8
    -----    -----  -----  -----  -----  -----  -----  -----  -----
     1        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
     2        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
     4        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
     8        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
    16        1.00x  1.12x  1.10x  1.12x  1.14x  1.14x  1.12x  1.13x
    32        1.00x  1.13x  1.10x  1.10x  1.09x  1.09x  1.10x  1.09x
    64        1.00x  1.12x  1.08x  1.09x  1.09x  1.08x  1.09x  1.08x
   128        1.00x  1.10x  1.07x  1.09x  1.10x  1.10x  1.10x  1.10x
   256        1.00x  1.10x  1.07x  1.09x  1.10x  1.10x  1.10x  1.10x
   512        1.00x  1.10x  1.07x  1.10x  ERROR  ERROR  ERROR  ERROR

DRAM SOL (%, of 8.0 TB/s B200 peak) at the dispatcher's output:

    B \ T     T=1    T=2    T=3    T=4    T=5    T=6    T=7    T=8
    -----    -----  -----  -----  -----  -----  -----  -----  -----
     1       15.3   14.6   18.4   20.0   18.1   19.1   20.1   20.9
     2       24.9   24.9   29.8   32.2   30.0   31.2   32.6   33.9
     4       36.6   33.1   40.4   42.6   39.9   41.2   42.0   43.2
     8       47.3   46.8   50.0   50.4   49.0   49.7   50.0   50.6
    16       55.7   60.1   62.5   63.2   62.0   62.3   61.9   61.8
    32       64.3   67.4   68.1   67.7   65.2   65.5   65.8   65.3
    64       70.9   72.1   72.6   72.3   69.9   69.1   69.0   68.1
   128       75.0   75.4   75.5   75.1   73.4   72.5   72.2   71.1
   256       78.1   77.6   77.6   77.3   75.5   74.4   73.8   72.6
   512       80.2   79.0   78.8   78.4   ERROR  ERROR  ERROR  ERROR

Peak SOL: 80.2 % at B=512 T=1 (T=1 path unchanged; reflects the earlier
`7b7f1ac3` heuristic). Peak SOL on wide_vec-dispatched cells: 79.0 % at
B=512 T=2. Baseline-only peak at the same cells was ~71 % (B=512 T=2).
B=512 T>=5 hits the known cudaErrorIllegalAddress from
results/2026-04-03/benchmark_results.md (unrelated; pre-existing).

Correctness: 282 / 282 pytest configs pass on B=1..256, T=2..8, HV in {32, 64},
cache_intermediate_states in {True, False} via the official
`_test_gdn_decode_bf16_state_mtp_kernel` helper, reused across both
`test_gdn_decode_bf16_state_mtp_kernel` (baseline, unchanged) and the new
`test_gdn_decode_bf16_state_wide_vec_mtp_kernel` (monkey-patched symbol).

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After rebasing the wide_vec dispatcher onto bf16_baseline_fix, the original
T=1 / T=2 small-batch heuristic from 192fa39d / 7b7f1ac3 regressed because
the _get_bf16_mtp_config helper was dropped in the conflict resolution
(upstream only had _select_tile_v_for_mtp for ILP=8 tile sizing). This
commit re-adds the ILP=4 variant so the PR ships without the B=4 T=2
regression that would otherwise show in CI.

Changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py:

1. New @cute.kernel gdn_decode_bf16state_mtp_ilp4_kernel. Same math as the
   ILP=8 kernel, but processes 4 V-rows/group/iter instead of 8.
   - ILP=4: ~48 regs/thread -> ~62 % occupancy (vs ILP=8's ~37 %)
   - Minimal signature: single h0_indices (read == write). Split-pool writes
     are delegated to ILP=8 by the dispatcher; the ILP=4 path never needs
     h0_out_indices, so we skip the extra plumbing.

2. New @cute.jit run_gdn_decode_bf16state_mtp_ilp4 launcher, mirrors the
   existing ILP=8 launcher's tile_v/grid/SMEM computation but calls the
   new kernel.

3. New _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim) helper:
     - work_units <= 128: (min(16, v_dim), 4)  # B<=2 at HV=64
     - seq_len == 2 and work_units <= 256: (tile_v, 4)  # covers B=4 HV=64
     - else: (tile_v, 8)
   The T=2 branch compensates for the ILP=8 pipeline stall from the inline
   g/beta recompute path (sGB is only populated for T > 2 in the ILP=8
   kernel; at small work_units ILP=8's 37 % occupancy cannot hide the
   softplus+exp+log+sigmoid latency).

4. gated_delta_rule_mtp dispatcher:
     - When output_state_indices is None: pick (tile_v, ilp_rows) via
       _get_bf16_mtp_config and route to the matching launcher.
     - When output_state_indices is not None: force ilp_rows=8 so the
       split-pool write path (h0_out_indices) stays on the ILP=8 kernel
       that supports it.
     - Cache key now includes ilp_rows so the two launchers don't collide.

Perf verification (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI,
5 warmup + 50 bench iters; cache_intermediate=True, disable_state_update=True
for T>=2; T=1 uses state-update ON):

              post-rebase before       post ILP=4 re-add
    B    T       (us)                   (us)
    1    1       3.46                    3.46
    1    2       5.43                    5.41
    2    1       4.26                    4.26
    2    2       6.51                    6.50
    4    1       5.82                    5.82
    4    2      11.36   <- +17 % regr.   9.68   <- back to reference
    8    1       9.25                    9.25
    8    2      13.95                   13.95
    16   1      15.06                   15.06
    16   2      21.50  (wide_vec)       21.50  (wide_vec)

Correctness: 30 / 30 test_gdn_decode_bf16_state_mtp_kernel configs pass at
B in {1, 2, 4, 8, 16}, T in {2, 4, 8}, cache_intermediate in {T, F} via
the official _test_gdn_decode_bf16_state_mtp_kernel helper. No test changes
needed — the existing test exercises the dispatcher and both ILP paths.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Follow-up to PR flashinfer-ai#3127. The shipped wide_vec kernel has a fixed tile_v=128
so its grid is only B*HV CTAs — at B <= 8 (HV=64) that's <= 512 CTAs on a
148-SM B200 (< 3.5 waves), and the dispatcher falls back to the baseline
ILP=4/8 path (LDG.E.64). This commit parameterizes tile_v so the grid gains
a V-tile dimension, letting the wide_vec kernel (LDG.E.128 on H) reach down
into small-batch sizes that were previously starved of SM parallelism.

Changes:

1. gdn_decode_bf16_state_wide_vec.py — tile_v is now a per-call constexpr:
   - Kernel decodes (i_n, i_hv, i_v) from block_idx; v_base offset by
     `i_v * tile_v`. ROWS_PER_GROUP and ITERS_PER_GROUP promoted from
     module-level constants to per-kernel constexprs.
   - Launcher computes grid = B * HV * (V / tile_v) and passes tile_v +
     num_v_tiles to the kernel.
   - Public wrapper `gated_delta_rule_mtp_wide_vec` accepts `tile_v: int = 128`
     (keyword-only default; positional callers unchanged).
   - Valid tile_v values: {32, 64, 128}. Asserted at the launcher entry.
   - Subgroup layout unchanged (16 threads x 8 BF16 per thread = LDG.E.128),
     so every tile_v variant uses the same per-thread memory-op widths.
   - Bit-exact with the previous fixed-tile_v=128 kernel at tile_v=128
     (verified: max_abs_diff = 0.0 over 252 pytest configs).

2. gdn_decode_bf16_state.py — new `_select_wide_vec_tile_v(B, HV, V=128)`
   helper (sketched below) picks tile_v by work_units = B * HV:
     work_units >= 1024 -> tile_v = 128  (unchanged: B >= 16 at HV=64)
     work_units >=  512 -> tile_v =  64  (new:   B =  8 at HV=64)
     work_units >=  128 -> tile_v =  32  (new:   B =  2..4 at HV=64)
     below              -> None (baseline ILP=4/8 as before)
   `_WIDE_VEC_WORK_UNITS_THRESHOLD` lowered 1024 -> 128 (module-scope
   constant kept for external benchmarks that monkey-patch it; actual
   dispatch uses the new picker). Dispatcher in `gated_delta_rule_mtp`
   now calls the picker and passes tile_v into `gated_delta_rule_mtp_wide_vec`.

3. tests/gdn/test_decode_delta_rule.py — `test_gdn_decode_bf16_state_wide_vec_mtp_kernel`
   parametrized over `tile_v in {32, 64, 128}`. num_v_heads restricted to
   {64} (Qwen3.5 production shape; HV=32 exploratory coverage is not needed
   for the small-batch path). Total configs: 9 batch x 7 T x 2 cache x 3 tile_v
   = 378.
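
A sketch of the picker logic from (2) (thresholds as listed above; the
committed helper may differ in details):

```python
def _select_wide_vec_tile_v(B: int, HV: int, V: int = 128):
    # Returns the wide_vec tile_v for a given work-unit count, or None
    # to fall back to the baseline ILP=4/8 path.
    work_units = B * HV
    if work_units >= 1024:
        return 128  # B >= 16 at HV=64
    if work_units >= 512:
        return 64   # B = 8 at HV=64
    if work_units >= 128:
        return 32   # B = 2..4 at HV=64
    return None     # baseline ILP=4/8
```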

Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI 5 warmup + 50 bench iters,
cache_intermediate=True, disable_state_update=True for T>=2; T=1 unchanged):

                post- 3127 before      post tile_v picker
    B    T       (us)                  (us)         speedup
    1    2       5.44                  5.46          1.00x  (picker -> tv=32 ties baseline)
    2    2       6.40                  6.34          1.01x  (picker -> tv=32)
    4    2       9.63                  8.14          1.18x  (picker -> tv=32)
    8    2      13.63                 12.74          1.07x  (picker -> tv=64)
    16   2      21.20                 21.18          1.00x  (tv=128, unchanged)
    32   2      37.81                 37.89          1.00x  (tv=128, unchanged)
    64   2      70.69                 70.69          1.00x  (tv=128, unchanged)

Biggest headline: B=4 T=2 goes from 9.63 us to 8.14 us (1.18x) by routing to
wide_vec tile_v=32, which was previously dispatched to the baseline ILP=4
kernel. Cumulative against origin/main at B=4 T=2: 11.14 us -> 8.14 us = 1.37x.

Correctness: 378 / 378 new pytest configs pass
(test_gdn_decode_bf16_state_wide_vec_mtp_kernel). Equivalent 1008-config
sanity run across tile_v = {32, 64, 128} x HV = {32, 64} also passed
during development (not in CI to keep the default matrix size reasonable).

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…l mode)

The T=1 ILP kernel (`gdn_decode_bf16state_ilp_kernel`) uses 4 warps × 32
threads with vec=4 BF16 per thread (LDG.E.64 / STG.E.64 on H state). NCU
captures at B=16/32 T=1 HV=64 show LSU pipe as the dominant bottleneck
(LSU 38.8–48.4 % vs DRAM 29.0–39.4 %), with L1/TEX at 67–75 %. The shipped
wide_vec kernel uses 8 groups × 16 threads / vec=8 BF16 (LDG.E.128 /
STG.E.128), which halves LSU instruction count and reduces L1 wavefronts.

wide_vec's SMEM-precompute phase runs `ceil(T / NUM_WARPS)` passes — at
T=1 that is 1 pass (only warp 0 does real work, others idle), so it
degenerates gracefully. Correctness: bit-exact at bf16 noise floor across
all 20 T=1 test configs (HV in {32, 64}, B in {1..512}, FP32 reference).

Changes:

1. gdn_decode_bf16_state.py::gated_delta_rule — new T=1 dispatch branch
   placed BEFORE the small-batch (`B < ILP_BATCH_THRESHOLD`) MTP redirect
   so B=8 HV=64 (work_units = 512) can also reach wide_vec at T=1 and get
   the LDG.E.128 win (measured 1.05x at B=8). Gated by ALL of:
     - K = V = 128 (wide_vec subgroup layout assumes this).
     - Pool mode: `initial_state_indices is not None`. Non-pool direct-state
       callers stay on the baseline ILP kernel, so the split-pool test path
       (baseline) remains bit-exact with the non-pool reference path that
       `test_output_state_indices` relies on.
     - Single-pool write: `output_state_indices is None` or identical to
       `initial_state_indices` (wide_vec has one indices tensor for R/W).
     - tile_v >= 64. At T=1 the wide_vec Phase 0 SMEM-precompute overhead is
       fixed per CTA while the main loop shrinks with tile_v. tile_v=32
       gives only 1 ILP iter per subgroup, insufficient to amortize
       Phase 0. Probe at HV=64 T=1 pool mode: tile_v=32 regresses at B=4
       (0.91x); tile_v=64 wins at B=8 (1.05x); tile_v=128 wins at B>=16
       (1.04-1.06x). So tile_v=32 is masked out at T=1 only.

2. tests/gdn/test_decode_delta_rule.py — `test_gdn_decode_bf16_state_t1_kernel`
   parametrization extended to also cover num_v_heads=64 (production
   Qwen3.5 shape). Total T=1 CI configs: 10 batch x 2 HV = 20.

Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, state-update ON, no cache,
pool mode with `initial_state_indices = arange(B)`, CUPTI 5 warmup + 50
bench iters). Baseline column monkey-patches `_select_wide_vec_tile_v`
to None so the ILP/MTP kernel path runs:

     B   baseline  wide_vec T=1  speedup
     1     3.49us        3.52us     0.99x  (MTP ILP=4; wide_vec gate n/a)
     2     4.29us        4.26us     1.01x  (MTP ILP=4; wide_vec gate n/a)
     4     5.76us        5.76us     1.00x  (MTP ILP=8; tile_v=32 masked)
     8     9.30us        8.83us     1.05x  (NEW: wide_vec tile_v=64)
    16    15.01us       14.35us     1.05x  (wide_vec tile_v=128)
    32    26.53us       25.30us     1.05x
    64    48.45us       46.53us     1.04x
   128    89.70us       86.00us     1.04x
   256   172.22us      165.02us     1.04x
   512   337.15us      323.10us     1.04x

Peak DRAM SOL at T=1 climbs from 80.6 % (previous ceiling) to 83.9 % at
B=512. All B >= 8 in pool mode gain 1.04-1.05x. Gains are smaller than
the T>=2 win (~1.10-1.50x) because the T=1 ILP kernel was already closer
to the DRAM roofline — LSU relief has less headroom to translate into
speedup.

Scope notes:
  - Non-pool (direct-state) T=1 callers: stay on baseline ILP. Routing
    them through wide_vec introduces 1-BF16-ULP output differences vs the
    baseline-served split-pool path, which fails test_output_state_indices
    at atol=1e-3. Since the test's semantic invariant is that split-pool
    and non-pool paths produce bit-identical output, both must use the
    same kernel; baseline ILP is the existing common denominator.
  - Split-pool writes (distinct output_state_indices tensor): stay on
    baseline ILP. wide_vec does not yet plumb h0_out_indices.
  - B*HV < 512: falls through to the (relocated) small-batch MTP redirect
    and baseline ILP=4/8 kernels. Wide_vec at tile_v=32 regresses here
    because Phase 0 overhead isn't amortized.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Drops gdn_decode_bf16state_cooprow_kernel (~270 lines) and its launcher
run_gdn_decode_bf16state_cooprow (~95 lines) from
flashinfer/gdn_kernels/gdn_decode_bf16_state.py. The cooprow path is
unreachable: gated_delta_rule routes to wide_vec / ILP / MTP-T1 only, and
the dispatch comment explicitly documents that cooprow had known
correctness issues at small batch.

Backward-compatible name aliases gated_delta_rule_bf16state_cooprow and
gated_delta_rule_bf16state_cooprow_mtp are kept (they alias to the live
gated_delta_rule and gated_delta_rule_mtp entry points, not to the
removed kernel) so external callers using those names continue to work.

Net delete: 376 lines.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The previous commit removed the cooprow kernel, which was the only
consumer of `cutlass.cute.nvgpu.cpasync`. Drop the now-unused import
to keep ruff happy.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The wide_vec module imported _reference_gdn_mtp from
gdn_decode_bf16_state_tma_horiz, but (a) the import was already marked
noqa: F401 (unused), and (b) gdn_decode_bf16_state_tma_horiz is a
working-tree-only scratch module that was never committed. With it
absent (e.g. on a fresh checkout) the wide_vec module fails to import,
which cascades into pytest collection failure for the GDN decode tests.

Drop the unused import. No behavior change; tests now load the wide_vec
module successfully.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The module's top-of-file docstring still described the cooprow path
(``Each warp processes ONE V-row at a time``, ``cp.async pipeline with
TILE_V=8 x TILE_K=128 tiles``) even though that kernel is gone. Replace
it with a current-state description that points at the live entry points
and the wide_vec / MTP fallback dispatch. Drop the module-level
constants that only the removed cooprow kernel used:

- TILE_V = 8, TILE_K = 128, NUM_STAGES = 2, NUM_THREADS = 128,
  NUM_BLOCKS_PER_STATE = 8 (cp.async pipeline tile sizes)
- MTP_TILE_K = 128 (never referenced; MTP kernels use MTP_NUM_THREADS /
  MTP_VEC_SIZE / MTP_ILP_ROWS only)
- ``_compiled_kernels: dict = {}`` (cooprow's compile cache; the live
  caches are ``_compiled_kernels_ilp`` and ``_compiled_kernels_mtp``)

No behavior change.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Both public entry points (``gated_delta_rule`` and ``gated_delta_rule_mtp``)
already assert ``K == 128 and V == 128`` at the top of the function, so
``_select_wide_vec_tile_v`` cannot be reached with V != 128. Drop the
redundant ``if V != 128: return None`` and the ``V`` parameter; the
remaining ``K == 128`` checks at the call sites also become redundant
and are inlined out.

No behavior change.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The BF16 GDN decode kernels are now strictly pool-mode: each batch
element reads/writes a slot in a shared ``[pool_size, HV, V, K]`` state
pool, indexed by ``initial_state_indices``. Production serving frameworks
(sglang, vllm) already use pool mode exclusively; non-pool mode was only
exercised by tests.

Changes:

- ``gated_delta_rule()`` and ``gated_delta_rule_mtp()`` in
  ``flashinfer/gdn_kernels/gdn_decode_bf16_state.py`` hard-assert
  ``initial_state_indices is not None``. The dispatchers no longer
  branch on ``use_pool``, no longer compute ``pool_size = B`` for the
  non-pool case, and no longer fall back to a synthesized
  ``cache["default_indices"]`` at runtime. The dummy ``arange(B)``
  kept inside the cached kernel state remains, but only as a
  cute.compile() shape template.

- ``gated_delta_rule_decode_pretranspose()`` in
  ``flashinfer/gdn_decode.py`` keeps its public ``use_pool`` /
  ``state`` API unchanged. When a caller hits the BF16 branch with
  non-pool semantics (passing ``state`` instead of ``initial_state``),
  the wrapper internally treats ``state`` as a pool of size B and
  synthesizes sequential indices ``arange(B)`` before calling the
  BF16 kernel. The math is identical (``pool_size = B``, sequential
  access), so no behavior change for non-pool callers (tests, etc.);
  the promotion is sketched below.

The FP32 ``gdn_decode_pretranspose.py`` kernel is unaffected — it
continues to support both pool and non-pool natively.
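
A sketch of the wrapper-side promotion (helper name and index dtype
assumed):

```python
import torch

def _promote_to_pool(state: torch.Tensor):
    # A non-pool [B, HV, V, K] state tensor is treated as a pool of
    # size B with sequential slot indices; the pool-only BF16 kernel
    # then performs exactly the same memory accesses as before.
    B = state.shape[0]
    indices = torch.arange(B, dtype=torch.int32, device=state.device)
    return state, indices  # (pool, initial_state_indices)
```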

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Removes ``gdn_decode_bf16state_ilp_kernel`` (the standalone T=1 ILP=8
baseline) and its launcher / cache / dispatch helpers. After non-pool
support was dropped, the T=1 ILP kernel was only reachable in a narrow
edge case (HV<32 with B>=ILP_BATCH_THRESHOLD=16, where wide_vec gates
out at tile_v<64) — i.e. never fires on Qwen3.5 (HV=64) or any other
production GDN shape. The MTP T=1 path (mtp_kernel ILP=8 / mtp_ilp4
ILP=4) covers the same shapes correctly.

Removed:
- ``gdn_decode_bf16state_ilp_kernel`` @cute.kernel and surrounding
  section header (~740 LOC)
- ``run_gdn_decode_bf16state_ilp`` @cute.jit launcher (~75 LOC)
- ``_compiled_kernels_ilp`` cache, ``ILP_BATCH_THRESHOLD`` constant
- ``_select_tile_v_for_batch`` helper (only used by ILP)
- ``TILE_V_ILP``, ``TILE_K_ILP``, ``NUM_THREADS_ILP``, ``VEC_SIZE_ILP``,
  ``ILP_ROWS`` module-level constants

Dispatcher change in ``gated_delta_rule()``: when wide_vec doesn't
fire (split-pool, or B*HV too small at T=1), unconditionally redirects
to ``gated_delta_rule_mtp`` instead of branching on ``B <
ILP_BATCH_THRESHOLD``. The MTP path picks ILP=4 vs ILP=8 via
``_get_bf16_mtp_config`` based on work_units.

Net: ``gdn_decode_bf16_state.py`` shrinks from 3046 to 2295 lines
(~1011 deletions, 40 insertions; mostly the ILP kernel body).

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…rnels

After the BF16 GDN kernel layer became pool-only, three direct-call
kernel tests still passed an ``[B, HV, V, K]`` state tensor without
``initial_state_indices`` and tripped the new pool-mode assertion.

Fix by treating the state tensor as a pool of size B and passing
sequential indices ``arange(B)``. The math is identical to the previous
non-pool semantics (pool_size = B, sequential access), so coverage of
the underlying kernel behavior is preserved.

Updated tests:
- ``_test_gdn_decode_bf16_state_kernel`` (used by
  ``test_gdn_decode_bf16_state_kernel``)
- ``_test_gdn_decode_bf16_state_t1_kernel`` (used by
  ``test_gdn_decode_bf16_state_t1_kernel``)
- ``test_pretranspose_api_uses_gdn_decode_bf16_state``

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The BF16 GDN kernels are pool-only post-cleanup. The bench harness was
calling the kernel without indices, which (a) errored at T=1 with
--update-state once the pool-only assertion was added, and (b) had no
way to drive the split-pool dispatch at all.

Adds:
- ``--pool-mode {single,split}`` argparse flag (default: single)
- ``bench_gdn_decode_bf16_state`` accepts ``pool_mode``: allocates a
  ``[2*B, HV, V, K]`` pool under split, synthesizes
  ``output_state_indices = arange(B, 2*B)`` so write slots are distinct
  from read slots; under single it allocates ``[B, HV, V, K]`` and
  passes ``output_state_indices=None`` (read==write); sketched below.
- ``gdn_decode_bf16_state_wrapper`` plumbs ``output_state_indices``
  through to both the T=1 and T>1 entry points; also passes
  ``initial_state_indices`` at T=1 (was previously omitted, which is
  what caused the ``--update-state`` regression).
- ``intermediate_states_buffer`` stays sized to ``[B, T, HV, V, K]``
  in both modes — the kernel keys it by READ index ``cache_idx`` and
  read indices are ``arange(B)`` regardless of pool mode.

Fixes the post-cleanup ``--update-state`` regression and unlocks
split-pool benchmarking. No kernel changes; baseline numbers should
match prior runs for single-pool, and split-pool exercises the
existing ``gdn_decode_bf16state_mtp_kernel`` (ILP=8) fallback path.
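
A sketch of the allocation scheme (helper name hypothetical; shapes and
index ranges as described above):

```python
import torch

def make_state_pool(B, HV, V, K, pool_mode="single", device="cuda"):
    read_idx = torch.arange(B, dtype=torch.int32, device=device)
    if pool_mode == "split":
        # Write slots [B..2B) are distinct from read slots [0..B).
        pool = torch.zeros(2 * B, HV, V, K, dtype=torch.bfloat16, device=device)
        write_idx = torch.arange(B, 2 * B, dtype=torch.int32, device=device)
    else:
        pool = torch.zeros(B, HV, V, K, dtype=torch.bfloat16, device=device)
        write_idx = None  # read == write
    return pool, read_idx, write_idx
```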

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Wide_vec was previously single-pool only — at the dispatcher we routed
split-pool callers (output_state_indices != initial_state_indices) to
the slower mtp_kernel ILP=8 fallback. Per the apr27 baseline run on
B200 HV=64 K=V=128, that costs 5.9-14.9% across the marquee shapes:

  Shape       single (wide_vec)  split (mtp_kernel)  overhead
  T=1 B=8     8.67 us            9.18 us             +5.9%
  T=1 B=32    25.22 us           28.99 us            +14.9%
  T=2 B=8     12.72 us           13.76 us            +8.2%
  T=2 B=32    38.06 us           42.58 us            +11.9%
  T=4 B=8     18.67 us           21.04 us            +12.7%
  T=4 B=32    62.29 us           69.20 us            +11.1%

Adds:
- ``h0_out_indices: cute.Tensor`` parameter to ``gdn_wide_vec_kernel``
  (and ``_run_wide_vec`` launcher).
- ``write_cache_idx`` derived from ``h0_out_indices[i_n]`` with the
  same negative-index null-buffer redirect as the read side.
- Pre-computed write-side tiles ``ht_w0..ht_w3`` at
  ``flat_write_state_idx = write_cache_idx * HV + i_hv``. The final
  state writeback now stores via these write tiles. ``cute.local_tile``
  is metadata-only so constructing both views every iteration costs
  nothing when single-pool callers reuse the same indices.
- Intermediate-state cache continues to key by READ index (cache_idx);
  it represents per-request cached state and is independent of where
  the final committed state lives.
- Python entry ``gated_delta_rule_mtp_wide_vec`` now accepts
  ``output_state_indices`` and forwards it to the kernel. When None,
  defaults to ``initial_state_indices`` so single-pool callers don't
  need to change.
- Dispatcher in ``gated_delta_rule`` and ``gated_delta_rule_mtp``
  forwards ``output_state_indices`` and drops the
  ``output_state_indices is None`` gate that previously routed split
  callers to the ILP=8 fallback.

Verified on B200 with a B=16 T=4 HV=64 micro-test: split-pool produces
bit-identical output to single-pool, the read slots [0..B) are
preserved, the write slots [B..2B) hold the same final state, and
single-pool still updates the read slots in place.

The legacy ``gdn_decode_bf16state_mtp_kernel`` is still reachable when
wide_vec gates out (work_units<128 at T>=2 with split-pool) — the next
commit extends ``gdn_decode_bf16state_mtp_ilp4_kernel`` to cover that
case so mtp_kernel can be removed.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…rnel

Previously the ILP=4 kernel had a minimal signature (single ``h0_indices``
for both read and write), and the dispatcher in ``gated_delta_rule_mtp``
forced ILP=8 whenever ``output_state_indices`` was non-None. This left
the work_units<128 split-pool case (B=1 at HV=64) on the slower ILP=8
path even though ILP=4 has higher occupancy at that shape.

Mirrors the wide_vec split-pool change:
- Adds ``h0_out_indices`` parameter to ``gdn_decode_bf16state_mtp_ilp4_kernel``
  and the ``run_gdn_decode_bf16state_mtp_ilp4`` launcher.
- Computes ``write_cache_idx`` with the same ``< 0 -> slot 0`` redirect.
- Builds write-side tile views ``hta_w..htd_w`` at
  ``flat_write_state_idx`` and uses them in the final-state writeback.
- Dispatcher in ``gated_delta_rule_mtp`` drops the
  ``if output_state_indices is None`` branch that forced ILP=8 — the
  config picker is now pool-mode-agnostic.
- Updates the cache build / runtime call to plumb ``h0_out_indices`` /
  ``output_state_indices``.

The intermediate-state cache continues to key by READ index
(cache_idx), matching wide_vec's semantics (per-request cache,
independent of write destination).

Verified on B200 with a B=1 T=4 HV=64 micro-test (the smallest
work_units shape, exercises the ilp4 path): split-pool produces
bit-identical output to single-pool, read slot preserved, write slot
matches single-pool result, single-pool still in-place updates.

After this commit, ``gdn_decode_bf16state_mtp_kernel`` (ILP=8) is no
longer reachable from any shape — every split-pool / single-pool
combination at every (B, T) routes to wide_vec or mtp_ilp4. The next
commit removes mtp_kernel and its launcher.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After wide_vec gained split-pool support and mtp_ilp4 likewise,
``gdn_decode_bf16state_mtp_kernel`` (ILP=8) is no longer reachable
from any dispatch path. The remaining MTP-fallback shapes (work_units
< 128, T=1 small batch with tile_v < 64) all benefit from ILP=4's
higher occupancy, so ``_get_bf16_mtp_config`` collapses to always
returning ``ilp_rows = 4``.

Removed:
- ``gdn_decode_bf16state_mtp_kernel`` @cute.kernel (~924 LOC) and its
  surrounding section header.
- ``run_gdn_decode_bf16state_mtp`` @cute.jit launcher (~95 LOC) and
  surrounding section header.
- The ILP=8 branch of ``_get_bf16_mtp_config`` (T=2-stall heuristic
  is unnecessary now — wide_vec covers T=2 work_units >= 128 and
  ILP=4 covers the rest).
- The ``MTP_ILP_ROWS = 8`` constant (only the removed kernel used it).
- The ``ilp_rows == 4 / else`` branching in the MTP cache build and
  runtime call sites — both branches now go to mtp_ilp4.

Updated:
- Module top-of-file docstring: drops mtp_kernel from the dispatch
  list and notes that split-pool is supported natively.
- ``_select_tile_v_for_mtp`` docstring: tile_v multiple of
  ``4 * MTP_ILP4_ROWS`` (= 16), not the old ``MTP_ILP_ROWS * 4``.

Verified on B200 with a 7-shape sweep covering each dispatch path
(B in {1, 4, 8, 32} x T in {1, 4}): output bit-identical between
single-pool and split-pool, split-pool read slots stay pristine.

Net: ``gdn_decode_bf16_state.py`` shrinks from 2295 to 1221 lines
(-1074 LOC, ~47% smaller). The BF16 surface is now exactly two
@cute.kernel definitions: ``gdn_wide_vec_kernel`` (the fast path)
and ``gdn_decode_bf16state_mtp_ilp4_kernel`` (the small-batch
fallback) — both pool-only, both split-pool capable.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Adds ``test_gdn_decode_bf16_state_mtp_split_pool`` (12 cases) which
calls ``gated_delta_rule_mtp`` directly with split-pool semantics
(``output_state_indices != initial_state_indices``) and verifies:

  - Output is bit-equivalent (within bf16 noise) to the single-pool
    dispatch on the same inputs.
  - Read slots [0..B) stay pristine in split mode.
  - When state-update is enabled (cache_intermediate_states=False),
    the split write slots [B..2B) hold the same final state that the
    single-pool dispatch wrote into the read slots.

Sweeps three batch sizes that hit each kernel:
  - B=1   -> mtp_ilp4 (work_units=64 < 128, wide_vec gates out)
  - B=8   -> wide_vec (work_units=512)
  - B=32  -> wide_vec (work_units=2048)

Plus T in {2, 4} and cache_intermediate_states in {True, False}.
Complements the existing ``test_output_state_indices`` (which
exercises T=1 split-pool through the wrapper) by covering the T>=2
direct-MTP path. The three invariants are sketched below.
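
A sketch of those three checks (tensor names hypothetical; the real
test drives gated_delta_rule_mtp in both pool modes and compares):

```python
import torch

def check_split_pool_invariants(out_single, out_split,
                                pool_split, pool_single, pool_snapshot, B):
    torch.testing.assert_close(out_split, out_single)                 # same output
    torch.testing.assert_close(pool_split[:B], pool_snapshot[:B])     # read slots pristine
    torch.testing.assert_close(pool_split[B:2 * B], pool_single[:B])  # write slots match
```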

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…BF16 file

Now that the BF16 surface is exactly two ``@cute.kernel``s
(``gdn_wide_vec_kernel`` and ``gdn_decode_bf16state_mtp_ilp4_kernel``),
keeping them in two separate Python files added more friction than
isolation. Both call into the same dispatcher, share the FMA helpers,
and need the same ``NUM_SMS`` / ``_USE_PACKED_FMA`` runtime constants.

Merges the wide_vec file into ``flashinfer/gdn_kernels/gdn_decode_bf16_state.py``:

- Wide_vec layout constants (``LANES_PER_ROW``, ``ELEMS_PER_LANE``,
  ``NUM_WARPS``, ``NUM_THREADS``, ``NUM_GROUPS``, ``ILP_ROWS``) inlined
  next to the MTP constants. Drops the unused ``TILE_K = 128``
  documentation constant from the wide_vec block.
- ``gdn_wide_vec_kernel`` and its ``_run_wide_vec`` launcher land in
  their own section in the merged file. Section ordering is now:
  wide_vec kernel + launcher, then mtp_ilp4 kernel + launcher.
- ``_compiled_kernels_wide_vec`` cache lives next to
  ``_compiled_kernels_mtp``.
- ``gated_delta_rule_mtp_wide_vec`` Python entry sits between the
  helper functions and the public dispatchers ``gated_delta_rule`` /
  ``gated_delta_rule_mtp``. Drops the two cross-file lazy imports
  (``from .gdn_decode_bf16_state_wide_vec import …``) that the
  dispatchers used to do — same-module references now.
- Drops the duplicated ``fma_pair_mul`` / ``fma_pair`` definitions and
  ``NUM_SMS`` / ``_GPU_MAJOR`` / ``_USE_PACKED_FMA`` constants that
  appeared in both files; the merged file uses the existing main-file
  definitions.

Test ``test_gdn_decode_bf16_state_wide_vec_mtp_kernel`` updated to
import ``gated_delta_rule_mtp_wide_vec`` from the merged location.

The two old files (1221 + 799 = 2020 lines) become one file at 1973
lines (~47 LOC saved from dedup of imports, FMA helpers, runtime
constants, and section headers).

No behavior change. Full BF16 / wide_vec test sweep verified prior to
merge; cross-shape correctness (B in {1, 8, 32} × T in {2, 4})
re-checked on the merged file (single == split bit-equivalent at
bf16 noise floor). Pre-commit clean.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
… PR flashinfer-ai#3145)

The ``intermediate_states_buffer`` is BATCH-scoped — shape ``[B, T, HV, V, K]``
— but both the wide_vec kernel (line 1115) and the mtp_ilp4 kernel
(line 621) were indexing it with ``cache_idx * T * HV + i_t * HV + i_hv``
where ``cache_idx`` is the POOL slot from ``initial_state_indices[i_n]``.

When ``pool_size > B`` (every realistic serving config) and
``initial_state_indices`` points at slots ``>= B`` (e.g. middle of a
1024-slot pool while servicing a B=32 batch), ``cache_idx * T * HV``
exceeds the buffer's ``B * T * HV`` extent and the
``cute.local_tile`` write goes off the end of the cache buffer ->
``cudaErrorIllegalAddress`` or silent memory corruption.

This is the same bug upstream PR flashinfer-ai#3145 fixed in the now-removed
``gdn_decode_bf16state_mtp_kernel``; both surviving BF16 kernels
inherited the incorrect pattern.

Fix (index arithmetic sketched below):
- ``gdn_decode_bf16state_mtp_ilp4_kernel``: ``flat_idx = i_n * T * HV + ...``
  (was ``cache_idx * T * HV + ...``).
- ``gdn_wide_vec_kernel``: same.
- Dispatcher (both ``gated_delta_rule_mtp`` and
  ``gated_delta_rule_mtp_wide_vec``): assert
  ``intermediate_states_buffer.shape[0] == B`` and reshape using ``B``
  rather than ``buffer_size``. Also updates the comment / docstring to
  call out batch-scoped semantics explicitly.
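
A sketch of the corrected index arithmetic (symbols from this commit
message; helper name hypothetical):

```python
def intermediate_flat_index(i_n: int, i_t: int, i_hv: int, T: int, HV: int) -> int:
    # Row index into the batch-scoped intermediate_states buffer
    # ([B, T, HV, V, K], flattened over its first three dims). Using the
    # pool slot cache_idx = initial_state_indices[i_n] here instead of
    # the per-call batch index i_n is the OOB bug: cache_idx can reach
    # pool_size - 1 >= B and overrun the B * T * HV extent.
    return i_n * T * HV + i_t * HV + i_hv
```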

Adds ``test_gdn_decode_bf16_state_mtp_pool_larger_than_batch`` (12
cases) which parametrizes ``pool_size_multiplier in {1, 4}`` and
``batch_size in {1, 8, 32}`` and ``seq_len in {2, 4}`` so both the
ilp4 path (B=1) and the wide_vec path (B=8/32) are exercised with
pool indices pointing at the upper half of a 4*B-slot pool. Verified
the test catches the bug: re-introducing the
``cache_idx * T * HV`` form makes the test fail with
``cudaErrorIllegalAddress``; reverting the line makes it pass again.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Several comments still described the old dispatch landscape (mtp_kernel
ILP=8 path, ``B*HV >= 1024`` wide_vec gate, pool-scoped intermediate
buffer) even though those have been replaced this session. No code
changes.

Updated:
- ``gated_delta_rule`` docstring: no longer claims the MTP fallback
  picks "ILP=4 or ILP=8 via _get_bf16_mtp_config" (now always ILP=4).
  Tightened the split-pool note — wide_vec handles split-pool natively
  on either dispatch path.
- ``gated_delta_rule_mtp_wide_vec`` docstring: gate threshold is
  ``B*HV >= 128`` (tile_v=32) up through ``>= 1024`` (tile_v=128), not
  just ``>= 1024``.
- ``gated_delta_rule_mtp`` docstring: ``intermediate_states_buffer`` is
  ``[B, T, HV, V, K]``, batch-scoped (was incorrectly documented as
  ``[pool_size, T, HV, V, K]``). Calls out the OOB-fix invariant for
  future readers.
- Comment in ``gated_delta_rule_mtp`` dispatcher: drops "wide_vec
  hardcodes TILE_V=128 so we gate on K==V==128 too" — TILE_V is now
  configurable. Drops "mtp_kernel (ILP=8)" reference (kernel removed).
- Comment in ``gated_delta_rule`` T=1 fallback: drops "dispatches to
  mtp_kernel (ILP=8) or mtp_ilp4_kernel" (only mtp_ilp4 left).
- Comment in ``gated_delta_rule_mtp`` post-wide_vec branch: only the
  ILP=4 MTP path is reachable now.
- Kernel param comments for ``intermediate_states``: shape
  ``[B * T * HV, V, K]``, not ``[pool_size * T * HV, V, K]``.

Two ``ILP=8`` references remain by design (line 92: explains the
mtp_ilp4 design lineage relative to the original ILP=8 kernel; line
1512: explains why ILP=4's higher occupancy beats ILP=8 at the
small-batch fallback shapes wide_vec doesn't hit). Both are useful
historical context, not stale claims.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…s (same_pool Constexpr)

Recovers the T=1 large-B single-pool regression vs apr22 documented in
``results/reports_apr27/regression_analysis.md`` (+4.6 % to +6.7 %
across B in {16, 32, 64, 128, 256} at T=1 with state-update on, the
production Qwen3.5 decode shape).

Root cause: split-pool support added the ability to write to a
distinct pool slot via ``h0_out_indices``. The dispatcher passes
identical read/write indices in single-pool mode, so the stores go
to the same slot they used to — but nvcc cannot prove this at compile
time (``output_state_indices is initial_state_indices`` is a runtime
identity), so the SASS materialises the extra LDG.32 +
negative-redirect compare + IMAD + 4×local_tile base-pointer ops per
V-iter even when they're guaranteed-equivalent to the read-side
arithmetic. In a DRAM-saturated kernel (T=1 B>=16 hits 74 % of peak
DRAM) this stretches the critical path by ~5-7 %.

Fix: a ``Constexpr[bool] same_pool`` kernel parameter, set by the
dispatcher to ``output_state_indices is None or output_state_indices
is initial_state_indices``. When True, the kernel aliases write-side
state to the read side at compile time:

  - ``write_cache_idx = cache_idx``
  - ``flat_write_state_idx = flat_state_idx``
  - ``ht_w* = ht*`` (write tiles alias read tiles)

nvcc DCEs the dead branch in the same_pool=True compile path; SASS
matches apr22's pre-split-pool kernel exactly. When False
(split-pool), the kernel uses the existing distinct write-side
machinery — bit-identical behaviour. The dispatcher-side predicate is
sketched below.
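
A sketch of that predicate (helper name hypothetical):

```python
def _same_pool(initial_state_indices, output_state_indices) -> bool:
    # Identity, not equality: the variant is selected at compile time,
    # so the check must not touch device memory or synchronize.
    return (output_state_indices is None
            or output_state_indices is initial_state_indices)
```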

Applied symmetrically to both BF16 kernels:
- ``gdn_wide_vec_kernel`` (the production hot path)
- ``gdn_decode_bf16state_mtp_ilp4_kernel`` (the small-batch fallback)

Public API is unchanged: callers (sglang, etc.) pass ``output_state_indices``
or omit it as before; the dispatcher computes ``same_pool`` and selects
the compile variant. JIT cache footprint roughly doubles per shape that
exercises both pool modes — for sglang's typical single-pool-only call
pattern, only the same_pool=True variant ever compiles.

Verified post-fix on B200: benchmarked in single-pool mode, --warmup 5
--iters 50, T=1 with --update-state, T>=2 with
--cache-intermediate-states. Numbers in microseconds; Δ is post-fix vs
the apr22 prototype reference at commit e7bee6ad of the same branch.

Full BS × T grid (single-pool, post-fix vs apr22):

         |    T=1 |    T=2 |    T=3 |    T=4 |    T=5 |    T=6 |    T=7 |    T=8
  -------+--------+--------+--------+--------+--------+--------+--------+-------
  apr22 prototype (us):
   B=1   |   3.52 |   5.42 |   5.86 |   6.69 |   8.83 |   9.79 |  10.58 |  11.39
   B=4   |   5.79 |   8.13 |   9.79 |  11.39 |  14.53 |  16.10 |  17.89 |  19.55
   B=8   |   9.28 |  12.67 |  15.90 |  18.91 |  23.55 |  26.42 |  30.59 |  33.95
   B=16  |  14.30 |  21.39 |  27.20 |  33.36 |  40.93 |  47.79 |  54.53 |  61.20
   B=32  |  25.30 |  38.13 |  49.94 |  62.08 |  77.62 |  89.94 | 103.22 | 117.01
   B=64  |  46.53 |  70.85 |  94.34 | 117.62 | 146.05 | 171.41 | 197.98 | 222.24
   B=128 |  86.00 | 134.91 | 180.27 | 225.54 | 278.67 | 328.54 | 379.52 | 427.78
   B=256 | 165.02 | 262.70 | 350.88 | 440.14 | 543.20 | 640.32 | 741.31 | 837.17

  post-fix (us):
   B=1   |   3.39 |   5.28 |   5.70 |   6.53 |   8.82 |   9.63 |  10.40 |  11.30
   B=4   |   5.98 |   7.87 |   9.62 |  11.26 |  14.48 |  16.22 |  17.76 |  19.44
   B=8   |   8.54 |  12.48 |  15.49 |  18.75 |  23.17 |  26.27 |  29.74 |  33.33
   B=16  |  14.34 |  21.55 |  27.15 |  33.87 |  41.38 |  48.67 |  55.50 |  63.14
   B=32  |  25.15 |  38.08 |  50.03 |  62.40 |  77.63 |  91.39 | 104.83 | 118.78
   B=64  |  46.50 |  70.88 |  93.92 | 118.18 | 150.19 | 175.70 | 200.34 | 230.40
   B=128 |  85.90 | 134.83 | 180.61 | 226.81 | 288.00 | 340.53 | 391.90 | 443.36
   B=256 | 164.98 | 261.76 | 351.49 | 441.97 | 559.82 | 665.23 | 764.11 | 868.05

  Δ post-fix vs apr22 (%):
   B=1   |  -3.7% |  -2.6% |  -2.7% |  -2.4% |  -0.1% |  -1.6% |  -1.7% |  -0.8%
   B=4   |  +3.3% |  -3.2% |  -1.7% |  -1.1% |  -0.3% |  +0.7% |  -0.7% |  -0.6%
   B=8   |  -8.0% |  -1.5% |  -2.6% |  -0.8% |  -1.6% |  -0.6% |  -2.8% |  -1.8%
   B=16  |  +0.3% |  +0.7% |  -0.2% |  +1.5% |  +1.1% |  +1.8% |  +1.8% |  +3.2%
   B=32  |  -0.6% |  -0.1% |  +0.2% |  +0.5% |   0.0% |  +1.6% |  +1.6% |  +1.5%
   B=64  |  -0.1% |   0.0% |  -0.4% |  +0.5% |  +2.8% |  +2.5% |  +1.2% |  +3.7%
   B=128 |  -0.1% |  -0.1% |  +0.2% |  +0.6% |  +3.3% |  +3.7% |  +3.3% |  +3.6%
   B=256 |   0.0% |  -0.4% |  +0.2% |  +0.4% |  +3.1% |  +3.9% |  +3.1% |  +3.7%

Headline: the T=1 regression is recovered. **No T=1 cell is more than
+3.3 % slower than apr22** (B=8 is 8.0 % faster), and the
previously-regressed T=1 B>=16 cells (were +4.6 % .. +6.7 %) are now
within +/- 0.6 %.

T=2/T=3/T=4 stay clean across the full B sweep (within +/- 3.2 %).

T=5..T=8 large-B residual (+2.5 % .. +3.9 % at B>=64) is unchanged by
this fix and not from the write-side machinery: with cache=ON the
writeback block is compile-time disabled, nvcc was already DCE'ing
the write-tile views, and ``same_pool`` has nothing extra to elide.
Likely from the kernel-entry LDG.32 of ``h0_out_indices`` (always
runs, regardless of caching). Defer to a follow-up if it matters;
production hot path (T=1) is recovered.

Split-pool path verified within +/- 0.2 % of pre-fix split-pool
numbers across (B, T) in {(8, 1), (8, 4), (32, 1), (32, 4), (256, 1),
(256, 4)} — split callers still hit the explicit write-side code via
the same_pool=False compile variant.

Bit-equivalence check across 8 (B, T) shapes (B in {1, 8, 32, 256} x
T in {1, 4}): single-pool output == split-pool output to bf16 noise
floor in every case.

Full pytest sweep: 513 passed, 0 failed in 18m18s
(``test_gdn_decode_bf16_state*``, ``test_decode_kernel_pretranspose_pool*``,
``test_output_state_indices*``).

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub marked this pull request as ready for review April 28, 2026 06:43
@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !613 has been updated with latest changes, and the CI pipeline #49720223 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1



📥 Commits

Reviewing files that changed from the base of the PR and between 7266b5e and ec76b2aac8ec8b2c7428af8ac1c68bd0453bf85e.

📒 Files selected for processing (1)
  • tests/conftest.py

Comment thread tests/conftest.py Outdated
Comment on lines +143 to +149
while exc is not None:
    if isinstance(exc, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(
        str(exc)
    ):
        return True
    exc = exc.__cause__ or exc.__context__
return False

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
def old_walk(exc):
    while exc is not None:
        if "out of memory" in str(exc).lower() and "cuda" in str(exc).lower():
            return True
        exc = exc.__cause__ or exc.__context__
    return False

def new_walk(exc):
    seen, stack = set(), [exc]
    while stack:
        cur = stack.pop()
        if id(cur) in seen:
            continue
        seen.add(id(cur))
        s = str(cur).lower()
        if "cuda" in s and "out of memory" in s:
            return True
        if cur.__cause__ is not None:
            stack.append(cur.__cause__)
        if cur.__context__ is not None:
            stack.append(cur.__context__)
    return False

ctx = RuntimeError("CUDA out of memory while formatting tensor diff")
cause = RuntimeError("wrapper cause without oom text")
top = RuntimeError("top-level runtime")
top.__cause__ = cause
top.__context__ = ctx

print("old_walk:", old_walk(top))  # expected False (missed)
print("new_walk:", new_walk(top))  # expected True (detected)
PY

Repository: flashinfer-ai/flashinfer

Length of output: 99


Traverse both __cause__ and __context__ branches to detect OOM in wrapped exceptions.

At line 148, the code uses exc = exc.__cause__ or exc.__context__, which follows only one link. When both are present, an OOM error in the non-selected branch will be missed, causing a test to fail instead of skip. The exception chain from torch.testing.assert_close can wrap inner OOM errors in multiple branches, so both must be examined.

Proposed fix
 def _is_oom_in_chain(exc: BaseException) -> bool:
-    # torch.testing.assert_close wraps the inner OOM in a RuntimeError("Comparing ...")
-    # whose str() does not mention "out of memory" — only __cause__ does.
-    while exc is not None:
-        if isinstance(exc, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(
-            str(exc)
-        ):
-            return True
-        exc = exc.__cause__ or exc.__context__
+    # torch.testing.assert_close may wrap inner exceptions; inspect the full chain.
+    seen: set[int] = set()
+    stack: list[BaseException] = [exc]
+    while stack:
+        cur = stack.pop()
+        cur_id = id(cur)
+        if cur_id in seen:
+            continue
+        seen.add(cur_id)
+        if isinstance(cur, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(
+            str(cur)
+        ):
+            return True
+        if cur.__cause__ is not None:
+            stack.append(cur.__cause__)
+        if cur.__context__ is not None:
+            stack.append(cur.__context__)
     return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/conftest.py` around lines 143 - 149, The traversal that checks for CUDA
OOM only follows one exception link with "exc = exc.__cause__ or
exc.__context__", so wrapped OOMs on the other branch can be missed; update the
logic in tests/conftest.py (the loop that inspects "exc" and calls
is_cuda_oom_error_str and checks isinstance(..., torch.cuda.OutOfMemoryError))
to traverse both __cause__ and __context__ branches (e.g., use a stack/queue or
set to visit exceptions breadth/depth-first, pushing exc.__cause__ and
exc.__context__ when present) until all linked exceptions are inspected and
return True if any is an OOM.

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !613 has been updated with latest changes, and the CI pipeline #49751839 is currently running. I'll report back once the pipeline job completes.

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !613 has been updated with latest changes, and the CI pipeline #49882216 is currently running. I'll report back once the pipeline job completes.

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !613 has been updated with latest changes, and the CI pipeline #49978944 is currently running. I'll report back once the pipeline job completes.

Member

@kahyunnam kahyunnam left a comment


LGTM, thanks!

@kahyunnam kahyunnam enabled auto-merge (squash) May 1, 2026 18:49
@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !613 has been updated with latest changes, and the CI pipeline #50042869 is currently running. I'll report back once the pipeline job completes.

The wide_vec MTP test materializes two full-tensor float32 upcasts of
[B, T, HV, V, K] BF16 in torch.testing.assert_close. At B=128, T=8 (or
B=256, T>=2) each upcast is ~2 GiB and torch.isclose OOMs. PyTorch
re-wraps the OutOfMemoryError as RuntimeError("Comparing\n\n..."), whose
str() lacks "CUDA"/"out of memory", so conftest's OOM-skip filter misses
it and the test reports as a failure rather than a skip.

Chunk the assertion along the batch axis (~256 MiB float32 peak per
chunk). _test_gdn_decode_bf16_state_mtp_kernel is shared with the
non-wide_vec MTP test, so both benefit.
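
A minimal sketch of the chunking pattern (helper name, chunk budget, and
the slicing details are illustrative, not the actual test code):

```python
import torch

def assert_close_chunked(actual, expected, chunk_bytes=256 << 20, **kwargs):
    # Compare [B, ...] BF16 tensors chunk by chunk along the batch axis so
    # the two float32 upcasts inside assert_close stay bounded (~chunk_bytes
    # peak) instead of materializing two ~2 GiB tensors at once.
    per_batch = actual[0].numel() * 4 * 2  # two fp32 copies per batch row
    step = max(1, chunk_bytes // per_batch)
    for start in range(0, actual.shape[0], step):
        torch.testing.assert_close(
            actual[start : start + step],
            expected[start : start + step],
            **kwargs,
        )
```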

Fixes failures on RTX 5090 (32 GiB):
  test_gdn_decode_bf16_state_wide_vec_mtp_kernel[
    bfloat16-128-16-16-64-128-{7,8}-True-32]
  test_gdn_decode_bf16_state_wide_vec_mtp_kernel[
    bfloat16-256-16-16-64-128-{2,3}-True-32]

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
auto-merge was automatically disabled May 3, 2026 03:53

Head branch was pushed to by a user without write access

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !613 has been updated with latest changes, and the CI pipeline #50118003 is currently running. I'll report back once the pipeline job completes.

@kahyunnam kahyunnam enabled auto-merge (squash) May 4, 2026 17:25
@kahyunnam kahyunnam merged commit 0739df3 into flashinfer-ai:main May 6, 2026
50 of 78 checks passed
vadiklyutiy added a commit to vadiklyutiy/flashinfer that referenced this pull request May 6, 2026
…-offset overflow

Same root cause as the original report — the CuTe-DSL pool+indices kernels
compute the per-slot element offset (cache_idx * HV + i_hv) * stride[0]
(or pool_idx * stride[0] for fp32) using Int32, so once the product
exceeds INT32_MAX the multiplication wraps to a negative offset and the
load/store hits an unmapped global address (CUDA error: an illegal
memory access was encountered). Discovered while integrating
gated_delta_rule_decode_pretranspose into vLLM's GDN decode path for
Qwen3.5-class models.

Backends fixed (HV=32, V=K=128 unless noted):

  fp32 pretranspose:
    gdn_decode_kernel_{small,big}_batch_pretranspose
      pool_idx * stride[0]                  overflows at pool_idx >= 3972
                                            (vLLM padded slot stride 540672)

  bf16 small-batch fallback (B*HV <= 128):
    gdn_decode_bf16state_mtp_ilp4_kernel
      (cache_idx * HV + i_hv) * V * K       overflows at cache_idx >= 4096

  bf16 production fast path
  (B*HV >= 128 at T>=2, B*HV >= 512 at T=1):
    gdn_wide_vec_kernel
      (cache_idx * HV + i_hv) * V * K       overflows at cache_idx >= 4096
      (i_n * T * HV + i_t * HV + i_hv)      overflows at B*T*HV >= 131072
        * V * K   (intermediate-states      e.g. B>=256 at T=8 HV=64
         cache writeback)                   with cache_intermediate_states=True

Fix: widen pool_idx / out_pool_idx (pretranspose) and cache_idx /
write_cache_idx (bf16) to Int64 immediately after they are read from
the indices tensors; widen the per-call batch index used for the
intermediate-states cache write (i_n) to Int64 in both bf16 kernels so
that flat_idx = i_n * T * HV + i_t * HV + i_hv inherits Int64 even when
B * T * HV crosses 131072. The downstream flat_state_idx,
flat_write_state_idx and flat_idx all stay Int64, so the offset
multiplications inside cute.local_tile / h0_source[(...)] cannot wrap.
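
A quick check of the thresholds quoted above, with numpy int32 standing
in for the kernel's Int32 arithmetic (HV=32, V=K=128 as in this message):

```python
import numpy as np

HV, V, K = np.int32(32), np.int32(128), np.int32(128)

with np.errstate(over="ignore"):  # silence numpy's overflow warning
    fits  = (np.int32(4095) * HV + 0) * V * K  # 2146959360: still positive
    wraps = (np.int32(4096) * HV + 0) * V * K  # 2^31 exactly: wraps negative
print(int(fits), int(wraps))  # 2146959360 -2147483648

# Widening the slot index first, as the fix does, keeps the product exact:
print(int((np.int64(4096) * HV + 0) * V * K))  # 2147483648
```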

Note on PR scope vs. PR flashinfer-ai#3147 (Ameyn/wide vec t1, merged on main after
this branch was cut): flashinfer-ai#3147 deleted gdn_decode_bf16state_mtp_kernel and
replaced the bf16 surface with gdn_decode_bf16state_mtp_ilp4_kernel
(small-batch fallback) + gdn_wide_vec_kernel (production hot path).
Both surviving kernels reproduce the same Int32 cache_idx offset bug as
the deleted kernel, and gdn_wide_vec_kernel additionally has an Int32
flat_idx offset bug on the intermediate-states writeback that's
reachable at production batch sizes (B >= 256 with MTP T=8, HV=64).
This commit fixes all three sites at once.

Regression test (tests/gdn/test_decode_pretranspose_noncontiguous_pool.py)
covers:
  * test_decode_pretranspose_pool_int64_offset[3973, 8191] — fp32
    vLLM-padded pool (~8.6 / 17.7 GB) at the pool_idx threshold.
  * test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196] — bf16
    contiguous pool (~4.3 GB) at the cache_idx threshold; at B=1 HV=32
    the dispatcher selects gdn_decode_bf16state_mtp_ilp4_kernel, which
    has the identical overflow site and identical Int64 fix as
    gdn_wide_vec_kernel. The wide_vec / intermediate-states fixes are
    structurally identical to the ilp4 fix exercised here.

Both bf16 / fp32 tests compare the pool path against a gather + direct-state
reference (numerical correctness, not just non-crashing) and assert the
in-place state update matches. VRAM-based skip when free memory is
insufficient.

AI-assisted (Cursor / Claude).

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>