
optimize gdn decode bf16 state kernel for mtp with caching. #3127

Draft
ameynaik-hub wants to merge 4 commits into flashinfer-ai:main from ameynaik-hub:ameyn/optimize_gdn_bf16_state_mtp_apr

Conversation

Contributor

@ameynaik-hub ameynaik-hub commented Apr 20, 2026

📌 Description

GDN BF16-state MTP — dispatcher output (B200)

Config: HV=64, H_Q=H_K=16, K=V=128, BF16, qk_l2norm=ON. T=1 uses state-update ON (no intermediate caching); T≥2 uses intermediate caching ON with state-update OFF. Measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state.

Wall-time (µs):

| Batch | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|------:|------:|------:|------:|------:|------:|------:|------:|------:|
| 1 | 3.49 | 5.42 | 5.79 | 6.69 | 8.80 | 9.76 | 10.59 | 11.55 |
| 2 | 4.26 | 6.43 | 7.14 | 8.24 | 10.64 | 11.90 | 12.99 | 14.14 |
| 4 | 5.79 | 9.68 | 10.48 | 12.51 | 16.05 | 18.11 | 20.22 | 22.34 |
| 8 | 9.18 | 13.70 | 17.04 | 21.12 | 26.08 | 29.95 | 34.03 | 37.86 |
| 16 | 15.01 | 21.18 | 27.23 | 33.50 | 41.30 | 48.10 | 55.30 | 61.97 |
| 32 | 26.74 | 37.89 | 50.14 | 62.64 | 78.45 | 91.31 | 103.82 | 117.71 |
| 64 | 48.22 | 70.69 | 93.87 | 118.02 | 147.38 | 171.71 | 199.01 | 223.98 |
| 128 | 89.38 | 135.20 | 180.32 | 225.92 | 279.50 | 328.64 | 379.30 | 428.86 |
| 256 | 172.18 | 262.91 | 351.06 | 440.62 | 543.81 | 641.14 | 740.19 | 838.14 |
| 512 | 337.18 | 516.48 | 691.10 | 868.46 | ERROR | ERROR | ERROR | ERROR |

DRAM SOL % (of 8.0 TB/s B200 peak)

| Batch | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|------:|------:|------:|------:|------:|------:|------:|------:|------:|
| 1 | 15.2% | 14.7% | 18.4% | 19.9% | 18.2% | 19.1% | 20.1% | 20.8% |
| 2 | 24.9% | 24.8% | 29.8% | 32.3% | 30.0% | 31.3% | 32.8% | 33.9% |
| 4 | 36.6% | 32.9% | 40.6% | 42.6% | 39.8% | 41.2% | 42.2% | 43.0% |
| 8 | 46.1% | 46.5% | 50.0% | 50.4% | 49.0% | 49.8% | 50.1% | 50.7% |
| 16 | 56.4% | 60.2% | 62.5% | 63.6% | 61.9% | 62.1% | 61.7% | 62.0% |
| 32 | 63.4% | 67.3% | 67.9% | 68.0% | 65.2% | 65.4% | 65.7% | 65.3% |
| 64 | 70.3% | 72.1% | 72.5% | 72.2% | 69.4% | 69.5% | 68.6% | 68.6% |
| 128 | 75.8% | 75.4% | 75.5% | 75.4% | 73.2% | 72.7% | 72.0% | 71.6% |
| 256 | 78.7% | 77.6% | 77.6% | 77.3% | 75.3% | 74.5% | 73.8% | 73.3% |
| 512 | 80.4% | 79.0% | 78.8% | 78.5% | ERROR | ERROR | ERROR | ERROR |
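The SOL percentages follow from achieved bandwidth over the 8.0 TB/s peak. A minimal sketch of the arithmetic (the byte count below is an illustrative assumption: the BF16 recurrent state at HV=64, V=K=128 is 64 x 128 x 128 x 2 bytes per sequence, read once and written once; q/k/v/output traffic is ignored):

```python
def dram_sol_pct(bytes_moved: float, time_us: float, peak_tb_s: float = 8.0) -> float:
    """Percent of peak DRAM bandwidth achieved over one kernel invocation."""
    achieved_bytes_per_s = bytes_moved / (time_us * 1e-6)
    return 100.0 * achieved_bytes_per_s / (peak_tb_s * 1e12)

# Rough check against the B=256, T=1 cell (172.18 us): state read + write only.
state_bytes = 256 * 2 * (64 * 128 * 128 * 2)
sol = dram_sol_pct(state_bytes, 172.18)  # roughly 78%, near the table's 78.7%
```

State traffic alone lands within a point of the measured cell, which is consistent with the kernel being DRAM-bound at large batch.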

Speedup vs origin/main (8559397)

| Batch | T=1 | T=2 | T=3 | T=4 | T=5 | T=6 | T=7 | T=8 |
|------:|------:|------:|------:|------:|------:|------:|------:|------:|
| 1 | 2.14× | 1.06× | 1.08× | 1.06× | 1.09× | 1.07× | 1.08× | 1.06× |
| 2 | 1.95× | 1.04× | 1.04× | 1.06× | 1.05× | 1.03× | 1.05× | 1.05× |
| 4 | 1.65× | 1.15× | 1.00× | 1.00× | 1.00× | 1.00× | 1.00× | 0.99× |
| 8 | 1.42× | 1.01× | 1.00× | 1.00× | 1.00× | 1.00× | 1.00× | 1.01× |
| 16 | 1.03× | 1.10× | 1.11× | 1.13× | 1.13× | 1.13× | 1.12× | 1.12× |
| 32 | 1.04× | 1.13× | 1.08× | 1.10× | 1.09× | 1.09× | 1.10× | 1.09× |
| 64 | 1.04× | 1.11× | 1.08× | 1.09× | 1.08× | 1.09× | 1.08× | 1.09× |
| 128 | 1.02× | 1.10× | 1.07× | 1.09× | 1.09× | 1.10× | 1.10× | 1.10× |
| 256 | 1.01× | 1.10× | 1.07× | 1.09× | 1.10× | 1.10× | 1.10× | 1.11× |
| 512 | 1.00× | 1.10× | 1.07× | 1.10× | ERROR | ERROR | ERROR | ERROR |

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

ameynaik-hub and others added 2 commits April 18, 2026 19:30
…to ILP kernel

Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:

1) Per-call Python overhead fix
   Move torch.arange / torch.empty / from_dlpack / get_device_capability out
   of the steady-state call path. These were previously called on every
   invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
   of CUPTI-visible overhead per call at small BS. All default-tensor
   allocation and dlpack conversion is now done once, inside the
   `cache_key not in _compiled_kernels_*` block, and cached alongside the
   compiled kernel. Steady-state calls pass raw torch tensors directly to
   the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
   per-call torch.cuda.get_device_capability().

2) Pool + padding support on the ILP kernel
   gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
   accepts h0_indices and h0_out_indices, matching the MTP kernel's
   signature. Negative indices redirect to pool slot 0 (null buffer);
   writes go to a separate flat_write_idx so input and output pool slots
   can differ. The ILP launcher and gated_delta_rule wrapper thread the
   new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
   collapsed so pool+indices T=1 calls no longer detour through the
   heavier MTP kernel.
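The caching shape described in (1) can be sketched as follows — a hypothetical simplification, where `compile_kernel` and `make_defaults` stand in for the real `cute.compile` call and the one-time `arange`/`empty`/`from_dlpack` allocations:

```python
_compiled_kernels: dict = {}

def get_cached(cache_key, compile_kernel, make_defaults):
    """Compile the kernel and allocate default tensors once per key."""
    entry = _compiled_kernels.get(cache_key)
    if entry is None:
        # One-time cost paid on the first call for this shape/config:
        # kernel compilation plus default-tensor allocation.
        entry = {"kernel": compile_kernel(), "defaults": make_defaults()}
        _compiled_kernels[cache_key] = entry
    # Steady-state calls hit this path only: a dict lookup, no allocation.
    return entry["kernel"], entry["defaults"]
```

Moving the per-call allocations behind the cache lookup is what removes the ~3.75 µs of Python overhead at small batch sizes.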

Design choice: kernel always takes indices (no constexpr switch).
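A host-side illustration of the index convention (hypothetical helper; the kernel applies this per thread, not via tensor ops):

```python
def resolve_pool_slots(h0_indices, h0_out_indices):
    # Negative read indices redirect to pool slot 0, which the caller
    # reserves as a zero-filled null buffer; write indices resolve
    # independently, so a sequence may read slot i and commit its
    # updated state to slot j.
    read_slots = [i if i >= 0 else 0 for i in h0_indices]
    write_slots = [i if i >= 0 else 0 for i in h0_out_indices]
    return read_slots, write_slots
```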

Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200

Command:
    python benchmarks/bench_gdn_decode.py \
        --batch-size 1 4 8 16 32 64 128 256 512 \
        --num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
        --head-size 128 --dtype bfloat16 --warmup 20 --iters 200

Bf16State column results (us):

  BS | time
   1 |   3.71
   4 |   5.89
   8 |   9.18
  16 |  14.98
  32 |  26.56
  64 |  48.24
 128 |  89.66
 256 | 172.35
 512 | 337.60

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Contributor

coderabbitai Bot commented Apr 20, 2026

📝 Walkthrough

This PR extends FlashInfer's GDN decoding with pooled-state indexing support and introduces a wide-vector BF16 MTP kernel variant. Changes include modifying existing BF16 kernels to accept separate read/write pool indices, adding a new ILP=4 MTP kernel variant, introducing a wide-vector BF16 MTP kernel with intermediate-state caching, and updating dispatch logic to conditionally route between kernel backends.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Main Decode Dispatch<br>`flashinfer/gdn_decode.py` | Modified BF16 `T==1` control flow to handle both pool and non-pool modes; now conditionally selects `initial_state_source` and passes `initial_state_indices` and `output_state_indices` through the `T==1` path. |
| BF16 ILP State Kernels<br>`flashinfer/gdn_kernels/gdn_decode_bf16_state.py` | Expanded the existing ILP kernel to support pooled state buffers with separate read/write indices (`h0_indices`, `h0_out_indices`); added a new MTP ILP=4 kernel variant; refactored Python dispatch logic in `gated_delta_rule` and `gated_delta_rule_mtp` to support pool-indexed reads/writes, kernel selection heuristics, and default index/output caching. |
| Wide-Vector BF16 MTP Kernel<br>`flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py` | New file implementing a CUTLASS/CuTe wide-vector BF16 MTP kernel with per-thread vector width 8, two-phase execution (SMEM precompute + main compute loop), intra-warp reductions via butterfly patterns, and intermediate-state caching support; includes the Python dispatch API `gated_delta_rule_mtp_wide_vec`. |
| Tests<br>`tests/gdn/test_decode_delta_rule.py` | Added an optional wide-vector BF16 MTP test that conditionally imports and monkeypatches the new kernel variant; parametrized over sequence length, batch size, head count, and caching modes. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~80 minutes

Possibly related issues

  • [RFC] Unified GDN Decode/Prefill API #2687: Main changes implement BF16 pooled-state indexing, split read/write pool semantics, new BF16 MTP/wide-vector kernels, and dispatch consolidation—directly addressing the requested BF16 pool+indices, MTP/wide-vector backend, and dispatch work.

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • bkryu
  • nvmbreughe
  • yongwww
  • cyx-6
  • kahyunnam

Poem

🐰 Pooled states now dance with index pairs,
Wide vectors soar through cached affairs,
Split reads and writes, a clever feat,
The decode path's now complete! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Description check | ⚠️ Warning | PR description lacks a clear summary of implementation details and rationale; it contains only performance benchmarks and generic checklist items. | Add a 📌 Description section explaining what the PR does (e.g., dispatcher optimization, kernel variants), why it's needed, and summarize the key changes across files. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly describes the main change: optimization of the GDN decode BF16 state kernel for MTP with intermediate-state caching support. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces significant performance optimizations for the BF16 GDN MTP decode kernels. Key improvements include a dynamic configuration mechanism to select optimal tiling and ILP (4 vs 8) based on workload size, and the addition of a new wide-vector kernel variant that utilizes 128-bit loads (8 BF16 elements) to maximize throughput for large batch sizes. The dispatcher logic has been updated to handle these new paths while reducing overhead by caching compiled kernels and minimizing redundant DLPack conversions. Feedback is provided regarding parameter naming consistency in the new wide-vector kernel implementation.

q: cute.Tensor,
k: cute.Tensor,
v: cute.Tensor,
b_gate: cute.Tensor,
Contributor


medium

For consistency with other kernels in this file and the project (e.g., gdn_decode_bf16state_mtp_kernel), it would be clearer to name this parameter b instead of b_gate. The public-facing API gated_delta_rule_mtp_wide_vec already uses b.

This change should be propagated to the function body and the _run_wide_vec wrapper as well.

Suggested change
b_gate: cute.Tensor,
b: cute.Tensor,

ameynaik-hub and others added 2 commits April 20, 2026 11:57
…stic fix

Three coordinated changes to the BF16 GDN MTP decode path on B200:

1. New kernel `gated_delta_rule_mtp_wide_vec`
   (flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py):
   - 128 threads/CTA organised as 8 groups of 16 threads (baseline: 4 of 32)
   - vec=8 BF16 per thread -> LDG.128 / STG.128 (baseline: vec=4 -> .64)
   - ILP=4 V-rows/thread, 4-stage butterfly over 16-lane subgroup
   - No TMA, no persistent CTAs; the win is purely wider LSU transactions
   - Peaks at 77.6 % DRAM SOL at B=256, T=2, HV=64

2. Auto-dispatcher inside `gated_delta_rule_mtp`
   (flashinfer/gdn_kernels/gdn_decode_bf16_state.py):
   - New module-level `_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024`
   - Routes to wide_vec when `B*HV >= 1024 and T >= 2 and K == V == 128`
   - T=1 and small-batch callers keep the existing baseline path verbatim,
     so the public API is source-and-ABI compatible.

3. B=4 HV=64 T=2 heuristic fix in `_get_bf16_mtp_config`:
   - Baseline's T=2 code path (gdn_decode_bf16_state.py) recomputes g/beta
     inline instead of reading from sGB (only T>2 pre-populates sGB). The
     inline softplus+exp+log+sigmoid stalls the ILP=8 pipeline at small
     work_units.
   - Fix: when `seq_len == 2 and work_units <= 256`, return (tile_v, 4)
     instead of (tile_v, 8). ILP=4 gives ~62 % occupancy vs ILP=8's ~37 %,
     covering the recompute latency.
   - Measured: B=4 HV=64 T=2: 11.20 us -> 9.63 us (1.17x).

Benchmark - Qwen3.5-397B-A17B Gated DeltaNet shape
(B200, HV=64, H_Q=H_K=16, K=V=128, BF16, cache_intermediate=True,
disable_state_update=True; T=1 uses state-update ON, no intermediate caching;
measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state with
CUPTI, 5 warmup + 50 bench iters per cell):

Wall-time (us), dispatcher's best-of-both output
(= baseline for B<=8, wide_vec for B>=16 at this HV; T=1 always baseline):

    B \ T       T=1      T=2      T=3      T=4      T=5      T=6      T=7      T=8
    -----    -------  -------  -------  -------  -------  -------  -------  -------
     1         3.46     5.44     5.79     6.67     8.83     9.79    10.61    11.50
     2         4.25     6.40     7.14     8.27    10.66    11.97    13.07    14.18
     4         5.79     9.63    10.53    12.51    16.03    18.11    20.32    22.22
     8         8.96    13.63    17.02    21.14    26.11    30.02    34.16    37.92
    16        15.22    21.20    27.23    33.73    41.23    47.89    55.15    62.11
    32        26.37    37.81    49.98    62.93    78.46    91.20   103.73   117.58
    64        47.76    70.69    93.76   117.92   146.42   172.80   197.82   225.44
   128        90.38   135.17   180.46   226.99   278.78   329.41   378.18   432.17
   256       173.65   262.98   351.06   440.75   542.37   641.63   739.60   846.40
   512       337.98   516.24   691.10   869.02    ERROR    ERROR    ERROR    ERROR

Wall-time (us), pre-session baseline (forced by monkey-patching
_WIDE_VEC_WORK_UNITS_THRESHOLD to 10**9):

    B \ T       T=1      T=2      T=3      T=4      T=5      T=6      T=7      T=8
    -----    -------  -------  -------  -------  -------  -------  -------  -------
     1         3.46     5.44     5.79     6.67     8.83     9.79    10.61    11.50
     2         4.25     6.40     7.14     8.27    10.66    11.97    13.07    14.18
     4         5.79     9.63    10.53    12.51    16.03    18.11    20.32    22.22
     8         8.96    13.63    17.02    21.14    26.11    30.02    34.16    37.92
    16        15.22    23.71    30.02    37.84    46.82    54.50    61.95    69.89
    32        26.37    42.56    54.75    69.01    85.41    99.74   114.13   128.50
    64        47.76    79.18   101.38   129.02   159.44   187.44   216.08   244.03
   128        90.38   149.34   193.09   247.38   305.87   361.33   417.20   473.81
   256       173.65   289.12   375.31   481.41   596.72   705.98   816.64   930.16
   512       337.98   568.32   741.25   952.97    ERROR    ERROR    ERROR    ERROR

Speedup (baseline / dispatcher). <=1.00x means dispatcher keeps baseline.
B=4 T=2 "1.00x" reflects the *new* heuristic (ilp=4) already in effect for
both columns; the pre-heuristic-fix baseline at that cell was 11.20 us
(1.16x slower than the post-fix 9.63 us).

    B \ T      T=1    T=2    T=3    T=4    T=5    T=6    T=7    T=8
    -----    -----  -----  -----  -----  -----  -----  -----  -----
     1        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
     2        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
     4        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
     8        1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x  1.00x
    16        1.00x  1.12x  1.10x  1.12x  1.14x  1.14x  1.12x  1.13x
    32        1.00x  1.13x  1.10x  1.10x  1.09x  1.09x  1.10x  1.09x
    64        1.00x  1.12x  1.08x  1.09x  1.09x  1.08x  1.09x  1.08x
   128        1.00x  1.10x  1.07x  1.09x  1.10x  1.10x  1.10x  1.10x
   256        1.00x  1.10x  1.07x  1.09x  1.10x  1.10x  1.10x  1.10x
   512        1.00x  1.10x  1.07x  1.10x  ERROR  ERROR  ERROR  ERROR

DRAM SOL (%, of 8.0 TB/s B200 peak) at the dispatcher's output:

    B \ T     T=1    T=2    T=3    T=4    T=5    T=6    T=7    T=8
    -----    -----  -----  -----  -----  -----  -----  -----  -----
     1       15.3   14.6   18.4   20.0   18.1   19.1   20.1   20.9
     2       24.9   24.9   29.8   32.2   30.0   31.2   32.6   33.9
     4       36.6   33.1   40.4   42.6   39.9   41.2   42.0   43.2
     8       47.3   46.8   50.0   50.4   49.0   49.7   50.0   50.6
    16       55.7   60.1   62.5   63.2   62.0   62.3   61.9   61.8
    32       64.3   67.4   68.1   67.7   65.2   65.5   65.8   65.3
    64       70.9   72.1   72.6   72.3   69.9   69.1   69.0   68.1
   128       75.0   75.4   75.5   75.1   73.4   72.5   72.2   71.1
   256       78.1   77.6   77.6   77.3   75.5   74.4   73.8   72.6
   512       80.2   79.0   78.8   78.4   ERROR  ERROR  ERROR  ERROR

Peak SOL: 80.2 % at B=512 T=1 (T=1 path unchanged; reflects the earlier
`7b7f1ac3` heuristic). Peak SOL on wide_vec-dispatched cells: 79.0 % at
B=512 T=2. Baseline-only peak at the same cells was ~71 % (B=512 T=2).
B=512 T>=5 hits the known cudaErrorIllegalAddress from
results/2026-04-03/benchmark_results.md (unrelated; pre-existing).

Correctness: 282 / 282 pytest configs pass on B=1..256, T=2..8, HV in {32, 64},
cache_intermediate_states in {True, False} via the official
`_test_gdn_decode_bf16_state_mtp_kernel` helper, reused across both
`test_gdn_decode_bf16_state_mtp_kernel` (baseline, unchanged) and the new
`test_gdn_decode_bf16_state_wide_vec_mtp_kernel` (monkey-patched symbol).

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After rebasing the wide_vec dispatcher onto bf16_baseline_fix, the original
T=1 / T=2 small-batch heuristic from 192fa39d / 7b7f1ac3 regressed because
the _get_bf16_mtp_config helper was dropped in the conflict resolution
(upstream only had _select_tile_v_for_mtp for ILP=8 tile sizing). This
commit re-adds the ILP=4 variant so the PR ships without the B=4 T=2
regression that would otherwise show in CI.

Changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py:

1. New @cute.kernel gdn_decode_bf16state_mtp_ilp4_kernel. Same math as the
   ILP=8 kernel, but processes 4 V-rows/group/iter instead of 8.
   - ILP=4: ~48 regs/thread -> ~62 % occupancy (vs ILP=8's ~37 %)
   - Minimal signature: single h0_indices (read == write). Split-pool writes
     are delegated to ILP=8 by the dispatcher; the ILP=4 path never needs
     h0_out_indices, so we skip the extra plumbing.

2. New @cute.jit run_gdn_decode_bf16state_mtp_ilp4 launcher, mirrors the
   existing ILP=8 launcher's tile_v/grid/SMEM computation but calls the
   new kernel.

3. New _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim) helper:
     - work_units <= 128: (min(16, v_dim), 4)  # B<=2 at HV=64
     - seq_len == 2 and work_units <= 256: (tile_v, 4)  # covers B=4 HV=64
     - else: (tile_v, 8)
   The T=2 branch compensates for the ILP=8 pipeline stall from the inline
   g/beta recompute path (sGB is only populated for T > 2 in the ILP=8
   kernel; at small work_units ILP=8's 37 % occupancy cannot hide the
   softplus+exp+log+sigmoid latency).

4. gated_delta_rule_mtp dispatcher:
     - When output_state_indices is None: pick (tile_v, ilp_rows) via
       _get_bf16_mtp_config and route to the matching launcher.
     - When output_state_indices is not None: force ilp_rows=8 so the
       split-pool write path (h0_out_indices) stays on the ILP=8 kernel
       that supports it.
     - Cache key now includes ilp_rows so the two launchers don't collide.

Perf verification (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI,
5 warmup + 50 bench iters; cache_intermediate=True, disable_state_update=True
for T>=2; T=1 uses state-update ON):

              post-rebase before       post ILP=4 re-add
    B    T       (us)                   (us)
    1    1       3.46                    3.46
    1    2       5.43                    5.41
    2    1       4.26                    4.26
    2    2       6.51                    6.50
    4    1       5.82                    5.82
    4    2      11.36   <- +17 % regr.   9.68   <- back to reference
    8    1       9.25                    9.25
    8    2      13.95                   13.95
    16   1      15.06                   15.06
    16   2      21.50  (wide_vec)       21.50  (wide_vec)

Correctness: 30 / 30 test_gdn_decode_bf16_state_mtp_kernel configs pass at
B in {1, 2, 4, 8, 16}, T in {2, 4, 8}, cache_intermediate in {T, F} via
the official _test_gdn_decode_bf16_state_mtp_kernel helper. No test changes
needed — the existing test exercises the dispatcher and both ILP paths.

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub force-pushed the ameyn/optimize_gdn_bf16_state_mtp_apr branch from 192fa39 to 16f6f14 on April 20, 2026 19:51
@ameynaik-hub ameynaik-hub marked this pull request as ready for review April 20, 2026 19:54
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 8

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py`:
- Around line 487-490: The final-state writeback currently skips updating
initial_state_source when caching is enabled (via effective_disable_final True),
which can silently drop the final state; update the conditional logic around
disable_state_update/cache_intermediate_states so that when disable_state_update
is False and cache_intermediate_states (or intermediate_states_buffer is
present) you explicitly copy intermediate_states_buffer[:, T-1] back into
initial_state_source before skipping the normal write-back (or alternatively
force wide-vec caching into verify-only mode). Apply the same fix to the other
identical final-state write-back site that follows the same pattern (the later
block using effective_disable_final/cutlass.const_expr).
- Around line 239-252: The softplus branch currently computes exp(beta_x_pre)
unconditionally which can overflow; clamp beta_x_pre to at most
softplus_threshold (e.g., beta_x_pre_clamped = min(beta_x_pre,
softplus_threshold)) before calling cute.exp/cute.log to avoid inf, then compute
exp_beta_x_pre and softplus_val_pre from beta_x_pre_clamped and keep the
existing use_softplus_pre mask and final softplus_x_pre selection (references:
softplus_beta, x_pre, beta_x_pre, softplus_threshold, exp_beta_x_pre,
softplus_val_pre, use_softplus_pre, softplus_x_pre, cute.exp, cute.log).
- Around line 618-630: The code currently creates a contiguous copy of
intermediate_states_buffer via intermediate_states =
intermediate_states.contiguous(), which means kernel writes to the copy and the
original buffer remains stale; change the logic in the cache_intermediate_states
branch to either assert that intermediate_states_buffer.is_contiguous() (and
raise/complain if not) before doing the reshape, or if you allow non-contiguous
inputs, allocate a contiguous working tensor and after kernel execution
copy/write the updated data back into intermediate_states_buffer (preserving
dtype torch.bfloat16 and the original shape), using the same HV_val/V_val/K_val
layout so caller’s buffer[:, T-1] contains the final state; refer to
intermediate_states_buffer, intermediate_states and cache_intermediate_states to
locate where to implement the assert or the copy-back.
- Around line 640-720: The cache_key currently omits per-tensor device, dtype
and stride info and also caches a reusable default_output that gets returned
across calls; update the cache key construction (the variable cache_key used to
index _compiled_kernels_wide_vec) to include q.device.type and q.device.index
plus the dtype and .stride() (as tuples) for every tensor passed into
from_dlpack/cute.compile (h0_source, intermediate_states, q, k, v, a, b, A_log,
dt_bias, output and initial_state_indices), and stop storing/reusing a single
default_output in the cache — only cache the compiled kernel and
default_indices, and always allocate a fresh output buffer when output is None
before calling from_dlpack/cute.compile so each call gets its own output tensor.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Around line 3076-3078: The code converts output_state_indices to torch.int32
but forgets to normalize initial_state_indices, causing compiled kernels (which
expect int32 index tensors) to fail when callers pass int64; update the logic
around use_pool/initial_state_indices to check if initial_state_indices is not
None and, if so, call initial_state_indices =
initial_state_indices.to(torch.int32) (mirror the output_state_indices
conversion), and apply the same change at the second occurrence mentioned (the
block around lines 3331-3332) so both code paths use int32 index tensors for the
compiled kernels.
- Around line 2678-2680: The comments in gdn_decode_bf16_state.py next to the
sQ, sK, and sGB size calculations use the Unicode multiplication sign "×" which
triggers RUF003; update those comment strings to use plain ASCII "x" (e.g.,
change "T × (K+8) × 4 bytes" to "T x (K+8) x 4 bytes") for the sQ, sK, and sGB
lines so linting passes and the intent remains the same; also scan the
surrounding comments in the same block for any other "×" characters and replace
them with "x".
- Around line 2554-2563: The flattened index into intermediate_states_buffer
uses T instead of cache_steps, causing out-of-bounds writes when cache_steps >
T; update the computation of flat_idx to use cache_steps (i.e., flat_idx =
cache_idx * cache_steps * HV + i_t * HV + i_hv) wherever flat_idx is computed
(reference symbols: intermediate_states_buffer, cache_idx, cache_steps, T, HV,
i_t, i_hv, flat_idx) and ensure cache_steps is declared/passed as a cutlass
constexpr through the kernel/launcher/cache key so compiled kernels differ when
the buffer step dimension changes; make the same change at the other occurrence
noted (the second block around the ILP=4 handling).
- Around line 3131-3133: The cached compile-time buffer (cache["output"]) is
being returned when output is None, causing later calls to overwrite previously
returned tensors; update the logic in the decode function(s) so that
cache["output"] is used only to infer shape/for compilation but you always
allocate and return a fresh tensor (e.g., create a new torch.empty with the same
shape/dtype/device as cache["output"] when output is None) and do not return the
cached object itself; apply the same change to the analogous spots that create
default_output/default_indices (the occurrences around the variables
default_output, default_indices and the other similar blocks at the other noted
locations).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 447d5233-bc08-40ba-89f2-f250cc53cb5f

📥 Commits

Reviewing files that changed from the base of the PR and between ce02358 and 16f6f14.

📒 Files selected for processing (4)
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py
  • tests/gdn/test_decode_delta_rule.py

Comment on lines +239 to +252
beta_x_pre = softplus_beta * x_pre
exp_beta_x_pre = cute.exp(beta_x_pre, fastmath=True)
softplus_val_pre = (cutlass.Float32(1.0) / softplus_beta) * cute.log(
    cutlass.Float32(1.0) + exp_beta_x_pre, fastmath=True
)
use_softplus_pre = (
    cutlass.Float32(1.0)
    if beta_x_pre <= softplus_threshold
    else cutlass.Float32(0.0)
)
softplus_x_pre = (
    use_softplus_pre * softplus_val_pre
    + (cutlass.Float32(1.0) - use_softplus_pre) * x_pre
)
Contributor


⚠️ Potential issue | 🟠 Major

Avoid evaluating the overflowing softplus branch.

This computes exp(beta_x_pre) before applying the threshold. For large positive x_pre, exp_beta_x_pre can become inf, and the arithmetic mask can produce 0 * inf -> NaN.

Proposed fix
-                exp_beta_x_pre = cute.exp(beta_x_pre, fastmath=True)
-                softplus_val_pre = (cutlass.Float32(1.0) / softplus_beta) * cute.log(
-                    cutlass.Float32(1.0) + exp_beta_x_pre, fastmath=True
-                )
-                use_softplus_pre = (
-                    cutlass.Float32(1.0)
-                    if beta_x_pre <= softplus_threshold
-                    else cutlass.Float32(0.0)
-                )
-                softplus_x_pre = (
-                    use_softplus_pre * softplus_val_pre
-                    + (cutlass.Float32(1.0) - use_softplus_pre) * x_pre
-                )
+                if beta_x_pre <= softplus_threshold:
+                    exp_beta_x_pre = cute.exp(beta_x_pre, fastmath=True)
+                    softplus_x_pre = (
+                        cutlass.Float32(1.0) / softplus_beta
+                    ) * cute.log(cutlass.Float32(1.0) + exp_beta_x_pre, fastmath=True)
+                else:
+                    softplus_x_pre = x_pre
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py` around lines 239 -
252, The softplus branch currently computes exp(beta_x_pre) unconditionally
which can overflow; clamp beta_x_pre to at most softplus_threshold (e.g.,
beta_x_pre_clamped = min(beta_x_pre, softplus_threshold)) before calling
cute.exp/cute.log to avoid inf, then compute exp_beta_x_pre and softplus_val_pre
from beta_x_pre_clamped and keep the existing use_softplus_pre mask and final
softplus_x_pre selection (references: softplus_beta, x_pre, beta_x_pre,
softplus_threshold, exp_beta_x_pre, softplus_val_pre, use_softplus_pre,
softplus_x_pre, cute.exp, cute.log).
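The failure mode is reproducible with plain Python floats (a standalone illustration, not the CUTE code): once the exponential overflows to infinity, the arithmetic mask multiplies 0 by inf and the resulting NaN survives the blend, while a true branch never evaluates the overflowing term.

```python
import math

def softplus_masked(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    # Mirrors the flagged pattern: the softplus term is computed even when
    # the mask will discard it, so inf leaks through as 0 * inf -> NaN.
    bx = beta * x
    exp_bx = math.exp(bx) if bx < 700.0 else math.inf  # fp32 exp overflows near 88
    sp = (1.0 / beta) * math.log1p(exp_bx)
    use_sp = 1.0 if bx <= threshold else 0.0
    return use_sp * sp + (1.0 - use_sp) * x

def softplus_branched(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    # The proposed fix: skip the overflowing branch entirely above threshold.
    bx = beta * x
    if bx <= threshold:
        return (1.0 / beta) * math.log1p(math.exp(bx))
    return x
```

Clamping the exponent before `exp` (as the prompt above suggests) achieves the same thing when the compiler forces both sides of a select to be evaluated.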

Comment on lines +487 to +490
# Final state write-back: skip when caching (inter[T-1] already has it)
if cutlass.const_expr(
    not disable_state_update and not cache_intermediate_states
):
Contributor


⚠️ Potential issue | 🟠 Major

Preserve state-update semantics when caching is enabled.

With intermediate_states_buffer set, effective_disable_final=True skips updating initial_state_source even when disable_state_update=False. If this path is reached through the existing MTP dispatcher, callers can silently stop seeing the final state in the input pool; either copy buffer[:, T-1] back for update-enabled calls or gate wide-vec caching to verify-only mode.

Also applies to: 631-635

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py` around lines 487 -
490, The final-state writeback currently skips updating initial_state_source
when caching is enabled (via effective_disable_final True), which can silently
drop the final state; update the conditional logic around
disable_state_update/cache_intermediate_states so that when disable_state_update
is False and cache_intermediate_states (or intermediate_states_buffer is
present) you explicitly copy intermediate_states_buffer[:, T-1] back into
initial_state_source before skipping the normal write-back (or alternatively
force wide-vec caching into verify-only mode). Apply the same fix to the other
identical final-state write-back site that follows the same pattern (the later
block using effective_disable_final/cutlass.const_expr).

Comment on lines +618 to +630
h0_source = initial_state_source.reshape(pool_size * HV_val, V_val, K_val)

cache_intermediate_states = intermediate_states_buffer is not None
if cache_intermediate_states:
    buffer_size = intermediate_states_buffer.shape[0]
    cache_steps = intermediate_states_buffer.shape[1]
    assert cache_steps >= T_val
    assert intermediate_states_buffer.dtype == torch.bfloat16
    intermediate_states = intermediate_states_buffer.reshape(
        buffer_size * cache_steps * HV_val, V_val, K_val
    )
    if not intermediate_states.is_contiguous():
        intermediate_states = intermediate_states.contiguous()
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the file and the specific lines in context
cat -n flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py | sed -n '610,650p'

Repository: flashinfer-ai/flashinfer



🏁 Script executed:

# Let's also check the entire function to understand the context
wc -l flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Search for the function containing these lines and understand the kernel launch pattern
ast-grep --pattern 'def $FUNC_NAME($$$) {
  $$$
  intermediate_states = intermediate_states_buffer.reshape($$$)
  $$$
}'



🏁 Script executed:

# Search for where intermediate_states is used after this reshape
rg "intermediate_states" flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py -A 3 -B 1



Reject non-contiguous intermediate_states_buffer or ensure writeback.

The kernel receives a reshaped view of intermediate_states_buffer. If that view is non-contiguous, .contiguous() allocates a copy; the kernel then updates the copy while intermediate_states_buffer remains stale. When cache_intermediate_states=True, the docstring requires the caller to read the final state from buffer[:, T-1], but non-contiguous inputs break this contract silently. Either assert contiguity before reshape, or copy results back to the original buffer after kernel launch.
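The assert-contiguity option could look like this (an illustrative sketch; `as_kernel_view` is a hypothetical helper, not the file's actual function):

```python
import torch

def as_kernel_view(
    intermediate_states_buffer: torch.Tensor, HV: int, V: int, K: int
) -> torch.Tensor:
    """Flatten the caching buffer into the layout the kernel indexes.
    Rejects non-contiguous input instead of calling .contiguous(), which
    would allocate a copy and leave the caller's buffer stale."""
    if not intermediate_states_buffer.is_contiguous():
        raise ValueError(
            "intermediate_states_buffer must be contiguous: the kernel "
            "writes through this view and the caller reads buffer[:, T-1]"
        )
    pool, steps = intermediate_states_buffer.shape[:2]
    # .view() never copies (it raises if it cannot alias), so kernel
    # writes are guaranteed to land in the original buffer
    return intermediate_states_buffer.view(pool * steps * HV, V, K)
```

The copy-back alternative is also valid but costs an extra device-to-device copy per launch; rejecting non-contiguous buffers keeps the hot path allocation-free.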

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py` around lines 618 -
630, The code currently creates a contiguous copy of intermediate_states_buffer
via intermediate_states = intermediate_states.contiguous(), which means kernel
writes to the copy and the original buffer remains stale; change the logic in
the cache_intermediate_states branch to either assert that
intermediate_states_buffer.is_contiguous() (and raise/complain if not) before
doing the reshape, or if you allow non-contiguous inputs, allocate a contiguous
working tensor and after kernel execution copy/write the updated data back into
intermediate_states_buffer (preserving dtype torch.bfloat16 and the original
shape), using the same HV_val/V_val/K_val layout so caller’s buffer[:, T-1]
contains the final state; refer to intermediate_states_buffer,
intermediate_states and cache_intermediate_states to locate where to implement
the assert or the copy-back.

Comment on lines +640 to +720
cache_key = (
"v3_mtp_bf16",
B_val,
T_val,
H_val,
HV_val,
K_val,
V_val,
pool_size,
effective_disable_final,
cache_intermediate_states,
use_qk_l2norm_in_kernel,
scale,
softplus_beta,
softplus_threshold,
use_packed_fma,
)
if cache_key not in _compiled_kernels_wide_vec:
default_indices = torch.arange(B_val, dtype=torch.int32, device=q.device)
default_output = torch.empty(
B_val, T_val, HV_val, V_val, device=q.device, dtype=q.dtype
)

if initial_state_indices is None:
initial_state_indices = default_indices
if output is None:
output = default_output

h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True)
inter_ = from_dlpack(intermediate_states, assumed_align=32, enable_tvm_ffi=True)
q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True)
k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True)
v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True)
a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True)
b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True)
A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True)
dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True)
h0_idx_ = from_dlpack(
initial_state_indices, assumed_align=32, enable_tvm_ffi=True
)

_compiled_kernels_wide_vec[cache_key] = {
"compiled": cute.compile(
_run_wide_vec,
h_,
inter_,
A_log_,
a_,
dt_bias_,
q_,
k_,
v_,
b_,
o_,
h0_idx_,
softplus_beta,
softplus_threshold,
scale,
HV_val,
B_val,
T_val,
H_val,
K_val,
V_val,
use_qk_l2norm_in_kernel,
effective_disable_final,
cache_intermediate_states,
use_packed_fma,
stream,
options="--enable-tvm-ffi --generate-line-info --opt-level 3",
),
"default_indices": default_indices,
"output": default_output,
}

cache = _compiled_kernels_wide_vec[cache_key]
if initial_state_indices is None:
initial_state_indices = cache["default_indices"]
if output is None:
output = cache["output"]
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find and read the relevant file
fd "gdn_decode_bf16_state_wide_vec.py"



🏁 Script executed:

# Read the file to examine the code around lines 640-720
wc -l flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Read lines 640-720 to see the code in question
sed -n '640,720p' flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Read more context to see the function signature and parameters
sed -n '550,650p' flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Check the end of the function to see how the cache is used
sed -n '720,736p' flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Check if the function has any decorators
sed -n '570,595p' flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Let's check if this is part of a high-level API by looking at imports and checking if decorators are used
head -50 flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Check if other similar functions in the codebase use decorators
rg "@flashinfer_api|@backend_requirement" flashinfer/gdn_kernels/ -A 3



🏁 Script executed:

# Check if these decorators exist anywhere in the codebase
rg "@flashinfer_api|@backend_requirement" flashinfer/ | head -20



🏁 Script executed:

# Check the decorator definitions in utils.py
rg "def flashinfer_api|def backend_requirement" flashinfer/utils.py -A 10



🏁 Script executed:

# Search for flashinfer_api decorator definition
rg "def flashinfer_api" flashinfer/ -A 10



🏁 Script executed:

# Check where this function is exported/used as a public API
rg "gated_delta_rule_mtp_wide_vec" flashinfer/ -B 2 -A 2



🏁 Script executed:

# Let me verify the actual problem by checking if different devices/dtypes could use the same cache
sed -n '600,660p' flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py



🏁 Script executed:

# Check if similar kernels in gdn_kernels have these issues
ls -la flashinfer/gdn_kernels/



🏁 Script executed:

# Check if similar patterns exist in other files in gdn_kernels
rg "_compiled_kernels" flashinfer/gdn_kernels/ -B 5 -A 15 | head -100



Fix cache key to include full tensor metadata and allocate fresh outputs on every call.

The compiled kernel depends on tensor device, dtype, and strides via DLPack conversion (lines 671–677), but the cache key omits device type/index, all tensor dtypes, and strides. This causes incorrect cache hits across devices or dtype variants.

More critically, when output=None, the code caches and reuses the same default_output tensor (lines 660–665, 710–712, 718–720). A second call with matching cache key returns the same tensor object, causing the kernel to overwrite the first call's result.

Extend the cache key to include q.device.type, q.device.index, dtypes of all tensors passed to cute.compile(), initial_state_indices.dtype, and strides of all tensors. Always allocate a fresh output buffer unless the caller explicitly provides one.

Sketch of safer structure
+    if output is None:
+        output = torch.empty(
+            B_val, T_val, HV_val, V_val, device=q.device, dtype=q.dtype
+        )
+
+    indices_dtype = (
+        initial_state_indices.dtype if initial_state_indices is not None else torch.int32
+    )
     cache_key = (
         "v3_mtp_bf16",
+        q.device.type,
+        q.device.index,
+        q.dtype,
+        k.dtype,
+        v.dtype,
+        a.dtype,
+        b.dtype,
+        A_log.dtype,
+        dt_bias.dtype,
+        initial_state_source.dtype,
+        indices_dtype,
+        output.dtype,
+        q.stride(),
+        k.stride(),
+        v.stride(),
+        a.stride(),
+        b.stride(),
+        initial_state_source.stride(),
         B_val,
         T_val,
         H_val,
@@
-            "output": default_output,
         }
@@
-    if output is None:
-        output = cache["output"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py` around lines 640 -
720, The cache_key currently omits per-tensor device, dtype and stride info and
also caches a reusable default_output that gets returned across calls; update
the cache key construction (the variable cache_key used to index
_compiled_kernels_wide_vec) to include q.device.type and q.device.index plus the
dtype and .stride() (as tuples) for every tensor passed into
from_dlpack/cute.compile (h0_source, intermediate_states, q, k, v, a, b, A_log,
dt_bias, output and initial_state_indices), and stop storing/reusing a single
default_output in the cache — only cache the compiled kernel and
default_indices, and always allocate a fresh output buffer when output is None
before calling from_dlpack/cute.compile so each call gets its own output tensor.

Comment on lines +2554 to +2563
if cutlass.const_expr(cache_intermediate_states):
for i in cutlass.range_constexpr(vec_size):
r_hb4_0[i] = cutlass.BFloat16(r_h[0, i])
r_hb4_1[i] = cutlass.BFloat16(r_h[1, i])
r_hb4_2[i] = cutlass.BFloat16(r_h[2, i])
r_hb4_3[i] = cutlass.BFloat16(r_h[3, i])

if cutlass.const_expr(cache_intermediate_states):
flat_idx = cache_idx * T * HV + i_t * HV + i_hv
ita = cute.local_tile(
Contributor

⚠️ Potential issue | 🟠 Major

Use cache_steps, not T, for intermediate-state pool stride.

intermediate_states_buffer allows shape[1] >= T, but the ILP=4 kernel flattens with cache_idx * T * HV. For cache_steps > T and cache_idx > 0, this writes into the wrong pool slot. Pass cache_steps as a constexpr and use it in the flattened index.

Proposed fix shape
 def gdn_decode_bf16state_mtp_ilp4_kernel(
@@
     T: cutlass.Constexpr[int],
+    cache_steps: cutlass.Constexpr[int],
@@
-                    flat_idx = cache_idx * T * HV + i_t * HV + i_hv
+                    flat_idx = cache_idx * cache_steps * HV + i_t * HV + i_hv

Apply the same cache_steps parameter through the launcher/cache key so compiled kernels are distinct when the buffer step dimension changes.

Also applies to: 3339-3348
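
The collision is easy to demonstrate with the flattened index math alone (plain-Python sketch; names mirror the kernel's variables):

```python
def flat_slot(cache_idx: int, i_t: int, HV: int, steps: int) -> int:
    # First flattened index of state (cache_idx, i_t) in a buffer of shape
    # [pool, steps, HV, V, K] flattened over its leading three dims.
    # The stride must come from the buffer's step dimension (cache_steps),
    # not from this launch's T.
    return cache_idx * steps * HV + i_t * HV

T, cache_steps, HV = 2, 4, 64   # buffer allocated for 4 steps, launch uses 2

buggy = flat_slot(1, 0, HV, T)            # kernel strides by T
correct = flat_slot(1, 0, HV, cache_steps)  # stride from the buffer

# With the T stride, pool slot 1 / step 0 lands on pool slot 0 / step 2
# of the properly strided layout: a silent cross-request overwrite.
assert buggy != correct
assert buggy == flat_slot(0, 2, HV, cache_steps)
```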

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2554 - 2563,
The flattened index into intermediate_states_buffer uses T instead of
cache_steps, causing out-of-bounds writes when cache_steps > T; update the
computation of flat_idx to use cache_steps (i.e., flat_idx = cache_idx *
cache_steps * HV + i_t * HV + i_hv) wherever flat_idx is computed (reference
symbols: intermediate_states_buffer, cache_idx, cache_steps, T, HV, i_t, i_hv,
flat_idx) and ensure cache_steps is declared/passed as a cutlass constexpr
through the kernel/launcher/cache key so compiled kernels differ when the buffer
step dimension changes; make the same change at the other occurrence noted (the
second block around the ILP=4 handling).

Comment on lines +2678 to +2680
4 * T * (K + 8) # sQ: T × (K+8) × 4 bytes (shared, one copy)
+ 4 * T * (K + 8) # sK: same
+ 4 * T * 2 # sGB: T × 2 × 4 bytes (shared)
Contributor

⚠️ Potential issue | 🟡 Minor

Replace Unicode multiplication signs in comments.

Ruff flags the × characters here with RUF003; use plain x to keep lint clean.

Proposed fix
-            4 * T * (K + 8)  # sQ: T × (K+8) × 4 bytes (shared, one copy)
+            4 * T * (K + 8)  # sQ: T x (K+8) x 4 bytes (shared, one copy)
             + 4 * T * (K + 8)  # sK: same
-            + 4 * T * 2  # sGB: T × 2 × 4 bytes (shared)
+            + 4 * T * 2  # sGB: T x 2 x 4 bytes (shared)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
4 * T * (K + 8) # sQ: T × (K+8) × 4 bytes (shared, one copy)
+ 4 * T * (K + 8) # sK: same
+ 4 * T * 2 # sGB: T × 2 × 4 bytes (shared)
4 * T * (K + 8) # sQ: T x (K+8) x 4 bytes (shared, one copy)
+ 4 * T * (K + 8) # sK: same
+ 4 * T * 2 # sGB: T x 2 x 4 bytes (shared)
🧰 Tools
🪛 Ruff (0.15.10)

[warning] 2678-2678: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)




[warning] 2680-2680: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)



🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 2678 - 2680,
The comments in gdn_decode_bf16_state.py next to the sQ, sK, and sGB size
calculations use the Unicode multiplication sign "×" which triggers RUF003;
update those comment strings to use plain ASCII "x" (e.g., change "T × (K+8) × 4
bytes" to "T x (K+8) x 4 bytes") for the sQ, sK, and sGB lines so linting passes
and the intent remains the same; also scan the surrounding comments in the same
block for any other "×" characters and replace them with "x".

Comment on lines +3076 to +3078
use_pool = initial_state_indices is not None
if output_state_indices is not None and output_state_indices.dtype != torch.int32:
output_state_indices = output_state_indices.to(torch.int32)
Contributor

⚠️ Potential issue | 🟠 Major

Normalize initial_state_indices to int32 too.

The kernels are compiled with int32 index tensors, but only output_state_indices is converted. Passing the usual PyTorch int64 indices for initial_state_indices can fail the compiled call or read the wrong slots.

Proposed fix
     use_pool = initial_state_indices is not None
+    if initial_state_indices is not None and initial_state_indices.dtype != torch.int32:
+        initial_state_indices = initial_state_indices.to(torch.int32)
     if output_state_indices is not None and output_state_indices.dtype != torch.int32:
         output_state_indices = output_state_indices.to(torch.int32)
@@
+    if initial_state_indices is not None and initial_state_indices.dtype != torch.int32:
+        initial_state_indices = initial_state_indices.to(torch.int32)
     if output_state_indices is not None and output_state_indices.dtype != torch.int32:
         output_state_indices = output_state_indices.to(torch.int32)

Also applies to: 3331-3332

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 3076 - 3078,
The code converts output_state_indices to torch.int32 but forgets to normalize
initial_state_indices, causing compiled kernels (which expect int32 index
tensors) to fail when callers pass int64; update the logic around
use_pool/initial_state_indices to check if initial_state_indices is not None
and, if so, call initial_state_indices = initial_state_indices.to(torch.int32)
(mirror the output_state_indices conversion), and apply the same change at the
second occurrence mentioned (the block around lines 3331-3332) so both code
paths use int32 index tensors for the compiled kernels.

Comment on lines +3131 to +3133
default_output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
default_indices = torch.arange(B, dtype=torch.int32, device=q.device)

Contributor

⚠️ Potential issue | 🟠 Major

Don’t return the cached compile-time output buffer.

When output is None, same-shape calls reuse cache["output"], so a later call overwrites any previously returned tensor. Keep the cached tensor only for compilation and allocate a fresh result unless the caller explicitly passes output.

Proposed fix
-            "output": default_output,
+            "_compile_output": default_output,
@@
     if output is None:
-        output = cache["output"]
+        output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
@@
-                "output": default_output,
+                "_compile_output": default_output,
@@
-                "output": default_output,
+                "_compile_output": default_output,
@@
     if output is None:
-        output = cache["output"]
+        output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)

Also applies to: 3192-3193, 3433-3435, 3531-3532
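
A stripped-down reproduction of the aliasing (illustrative only, not the file's code — `decode_buggy` mimics the cache-and-return pattern):

```python
import torch

_cache: dict = {}

def decode_buggy(x: torch.Tensor) -> torch.Tensor:
    """Cache one output buffer per shape/dtype and hand the same tensor
    object back on every call — the pattern flagged above."""
    key = (x.shape, x.dtype)
    if key not in _cache:
        _cache[key] = torch.empty_like(x)
    out = _cache[key]
    out.copy_(x)  # stand-in for the kernel writing its result
    return out

a = decode_buggy(torch.tensor([1.0]))
b = decode_buggy(torch.tensor([2.0]))
# a and b alias the same storage, so the second call clobbered the first
```

Allocating a fresh `torch.empty(...)` whenever `output is None` removes the aliasing at the cost of one allocation per call, which PyTorch's caching allocator keeps cheap.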

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py` around lines 3131 - 3133,
The cached compile-time buffer (cache["output"]) is being returned when output
is None, causing later calls to overwrite previously returned tensors; update
the logic in the decode function(s) so that cache["output"] is used only to
infer shape/for compilation but you always allocate and return a fresh tensor
(e.g., create a new torch.empty with the same shape/dtype/device as
cache["output"] when output is None) and do not return the cached object itself;
apply the same change to the analogous spots that create
default_output/default_indices (the occurrences around the variables
default_output, default_indices and the other similar blocks at the other noted
locations).

@ameynaik-hub
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !572 has been created, and the CI pipeline #49025840 is currently running. I'll report back once the pipeline job completes.

ameynaik-hub added a commit to ameynaik-hub/flashinfer that referenced this pull request Apr 28, 2026
Follow-up to PR flashinfer-ai#3127. The shipped wide_vec kernel has a fixed tile_v=128
so its grid is only B*HV CTAs — at B <= 8 (HV=64) that's <= 512 CTAs on a
148-SM B200 (< 3.5 waves), and the dispatcher falls back to the baseline
ILP=4/8 path (LDG.E.64). This commit parameterizes tile_v so the grid gains
a V-tile dimension, letting the wide_vec kernel (LDG.E.128 on H) reach down
into small-batch sizes that were previously starved of SM parallelism.

Changes:

1. gdn_decode_bf16_state_wide_vec.py — tile_v is now a per-call constexpr:
   - Kernel decodes (i_n, i_hv, i_v) from block_idx; v_base offset by
     `i_v * tile_v`. ROWS_PER_GROUP and ITERS_PER_GROUP promoted from
     module-level constants to per-kernel constexprs.
   - Launcher computes grid = B * HV * (V / tile_v) and passes tile_v +
     num_v_tiles to the kernel.
   - Public wrapper `gated_delta_rule_mtp_wide_vec` accepts `tile_v: int = 128`
     (keyword-only default; positional callers unchanged).
   - Valid tile_v values: {32, 64, 128}. Asserted at the launcher entry.
   - Subgroup layout unchanged (16 threads x 8 BF16 per thread = LDG.E.128),
     so every tile_v variant uses the same per-thread memory-op widths.
   - Bit-exact with the previous fixed-tile_v=128 kernel at tile_v=128
     (verified: max_abs_diff = 0.0 over 252 pytest configs).

2. gdn_decode_bf16_state.py — new `_select_wide_vec_tile_v(B, HV, V=128)`
   helper picks tile_v by work_units = B * HV:
     work_units >= 1024 -> tile_v = 128  (unchanged: B >= 16 at HV=64)
     work_units >=  512 -> tile_v =  64  (new:   B =  8 at HV=64)
     work_units >=  128 -> tile_v =  32  (new:   B =  2..4 at HV=64)
     below              -> None (baseline ILP=4/8 as before)
   `_WIDE_VEC_WORK_UNITS_THRESHOLD` lowered 1024 -> 128 (module-scope
   constant kept for external benchmarks that monkey-patch it; actual
   dispatch uses the new picker). Dispatcher in `gated_delta_rule_mtp`
   now calls the picker and passes tile_v into `gated_delta_rule_mtp_wide_vec`.

3. tests/gdn/test_decode_delta_rule.py — `test_gdn_decode_bf16_state_wide_vec_mtp_kernel`
   parametrized over `tile_v in {32, 64, 128}`. num_v_heads restricted to
   {64} (Qwen3.5 production shape; HV=32 exploratory coverage is not needed
   for the small-batch path). Total configs: 9 batch x 7 T x 2 cache x 3 tile_v
   = 378.

Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI 5 warmup + 50 bench iters,
cache_intermediate=True, disable_state_update=True for T>=2; T=1 unchanged):

                post- 3127 before      post tile_v picker
    B    T       (us)                  (us)         speedup
    1    2       5.44                  5.46          1.00x  (picker -> tv=32 ties baseline)
    2    2       6.40                  6.34          1.01x  (picker -> tv=32)
    4    2       9.63                  8.14          1.18x  (picker -> tv=32)
    8    2      13.63                 12.74          1.07x  (picker -> tv=64)
    16   2      21.20                 21.18          1.00x  (tv=128, unchanged)
    32   2      37.81                 37.89          1.00x  (tv=128, unchanged)
    64   2      70.69                 70.69          1.00x  (tv=128, unchanged)

Biggest headline: B=4 T=2 goes from 9.63 us to 8.14 us (1.18x) by routing to
wide_vec tile_v=32, which was previously dispatched to the baseline ILP=4
kernel. Cumulative against origin/main at B=4 T=2: 11.14 us -> 8.14 us = 1.37x.

Correctness: 378 / 378 new pytest configs pass
(test_gdn_decode_bf16_state_wide_vec_mtp_kernel). Equivalent 1008-config
sanity run across tile_v = {32, 64, 128} x HV = {32, 64} also passed
during development (not in CI to keep the default matrix size reasonable).

AI-assisted by Claude Code.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
@ameynaik-hub ameynaik-hub marked this pull request as draft April 28, 2026 06:40