Ameyn/wide vec t1 (#3147)
📝 Walkthrough

Implements support for separate input and output state indices in BF16-accelerated gated-delta-rule decoding, enabling split-pool dispatch where outputs are written to different pool slots than inputs are read from. Unifies both T==1 and T>1 code paths to route through consistent pool/index handling, and updates the benchmark wrapper with a --pool-mode flag.

Changes: BF16 Split-Pool State Dispatch
Code Review
This pull request introduces a wide-vector BF16 GDN MTP decode kernel and an ILP=4 variant to optimize performance across various batch sizes and work units. It also refactors the dispatch logic to support pool-based state management with indices and includes updated tests. Feedback from the review focuses on the significant code duplication across the new kernels and the repetitive nature of the memory write-back logic, suggesting that these be refactored into parameterized implementations or loops to improve maintainability.
…to ILP kernel
Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:
1) Per-call Python overhead fix
Move torch.arange / torch.empty / from_dlpack / get_device_capability out
of the steady-state call path. These were previously called on every
invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
of CUPTI-visible overhead per call at small BS. All default-tensor
allocation and dlpack conversion is now done once, inside the
`cache_key not in _compiled_kernels_*` block, and cached alongside the
compiled kernel. Steady-state calls pass raw torch tensors directly to
the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
per-call torch.cuda.get_device_capability().
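As a rough sketch of the caching pattern (names like _compile_for are illustrative stand-ins for the cute.compile / from_dlpack setup, not the module's real helpers):

import torch

# Evaluated once at import, replacing a per-call get_device_capability().
_USE_PACKED_FMA = torch.cuda.get_device_capability()[0] >= 9  # illustrative gate
_compiled_kernels: dict = {}

def _compile_for(cache_key):
    # Stand-in for the one-time cute.compile(...) + dlpack conversion.
    return lambda *tensors: None

def gated_delta_rule_sketch(q: torch.Tensor, state_pool: torch.Tensor) -> None:
    cache_key = (q.shape[0], q.dtype)
    entry = _compiled_kernels.get(cache_key)
    if entry is None:
        # One-time setup that previously ran on every call (~3.75 us):
        # default index tensors, output allocation, dlpack wrapping, compile.
        default_indices = torch.arange(q.shape[0], device=q.device, dtype=torch.int32)
        entry = (_compile_for(cache_key), default_indices)
        _compiled_kernels[cache_key] = entry
    compiled, default_indices = entry
    compiled(q, state_pool, default_indices)  # raw tensors straight to the FFI callable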
2) Pool + padding support on the ILP kernel
gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
accepts h0_indices and h0_out_indices, matching the MTP kernel's
signature. Negative indices redirect to pool slot 0 (null buffer);
writes go to a separate flat_write_idx so input and output pool slots
can differ. The ILP launcher and gated_delta_rule wrapper thread the
new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
collapsed so pool+indices T=1 calls no longer detour through the
heavier MTP kernel.
Design choice: kernel always takes indices (no constexpr switch).
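The index semantics reduce to the following per-(batch, value-head) resolution — a plain-Python sketch of what the kernel computes on-device; names are illustrative:

def resolve_state_slots(h0_indices, h0_out_indices, num_v_heads, i_n, i_hv):
    """Map a (batch, value-head) pair to flat read/write pool offsets."""
    cache_idx = int(h0_indices[i_n])
    if cache_idx < 0:            # padded entry: redirect to the null buffer
        cache_idx = 0            # pool slot 0 absorbs the read
    write_cache_idx = int(h0_out_indices[i_n])
    if write_cache_idx < 0:
        write_cache_idx = 0      # and the write
    flat_state_idx = cache_idx * num_v_heads + i_hv
    flat_write_idx = write_cache_idx * num_v_heads + i_hv
    return flat_state_idx, flat_write_idx  # may differ: split-pool dispatch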
Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200
Command:
python benchmarks/bench_gdn_decode.py \
--batch-size 1 4 8 16 32 64 128 256 512 \
--num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
--head-size 128 --dtype bfloat16 --warmup 20 --iters 200
Bf16State column results (us):
BS | time
1 | 3.71
4 | 5.89
8 | 9.18
16 | 14.98
32 | 26.56
64 | 48.24
128 | 89.66
256 | 172.35
512 | 337.60
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…stic fix
Three coordinated changes to the BF16 GDN MTP decode path on B200:
1. New kernel `gated_delta_rule_mtp_wide_vec`
(flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py):
- 128 threads/CTA organised as 8 groups of 16 threads (baseline: 4 of 32)
- vec=8 BF16 per thread -> LDG.128 / STG.128 (baseline: vec=4 -> .64)
- ILP=4 V-rows/thread, 4-stage butterfly over 16-lane subgroup
- No TMA, no persistent CTAs; the win is purely wider LSU transactions
- Peaks at 77.6 % DRAM SOL at B=256, T=2, HV=64
2. Auto-dispatcher inside `gated_delta_rule_mtp`
(flashinfer/gdn_kernels/gdn_decode_bf16_state.py):
- New module-level `_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024`
- Routes to wide_vec when `B*HV >= 1024 and T >= 2 and K == V == 128` (see the predicate sketch after this list)
- T=1 and small-batch callers keep the existing baseline path verbatim,
so the public API is source-and-ABI compatible.
3. B=4 HV=64 T=2 heuristic fix in `_get_bf16_mtp_config`:
- Baseline's T=2 code path (gdn_decode_bf16_state.py) recomputes g/beta
inline instead of reading from sGB (only T>2 pre-populates sGB). The
inline softplus+exp+log+sigmoid stalls the ILP=8 pipeline at small
work_units.
- Fix: when `seq_len == 2 and work_units <= 256`, return (tile_v, 4)
instead of (tile_v, 8). ILP=4 gives ~62 % occupancy vs ILP=8's ~37 %,
covering the recompute latency.
- Measured: B=4 HV=64 T=2: 11.20 us -> 9.63 us (1.17x).
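A sketch of the dispatch predicate from item 2 (the module constant name is real; the helper name is illustrative):

_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024

def _should_use_wide_vec(batch_size, seq_len, num_v_heads, k_dim, v_dim):
    # wide_vec assumes the 16-lane x 8-BF16 subgroup layout, which is only
    # valid for K == V == 128; T == 1 and small batches keep the baseline.
    work_units = batch_size * num_v_heads
    return (
        work_units >= _WIDE_VEC_WORK_UNITS_THRESHOLD
        and seq_len >= 2
        and k_dim == 128
        and v_dim == 128
    )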
Benchmark - Qwen3.5-397B-A17B Gated DeltaNet shape
(B200, HV=64, H_Q=H_K=16, K=V=128, BF16, cache_intermediate=True,
disable_state_update=True; T=1 uses state-update ON, no intermediate caching;
measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state with
CUPTI, 5 warmup + 50 bench iters per cell):
Wall-time (us), dispatcher's best-of-both output
(= baseline for B<=8, wide_vec for B>=16 at this HV; T=1 always baseline):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 21.20 27.23 33.73 41.23 47.89 55.15 62.11
32 26.37 37.81 49.98 62.93 78.46 91.20 103.73 117.58
64 47.76 70.69 93.76 117.92 146.42 172.80 197.82 225.44
128 90.38 135.17 180.46 226.99 278.78 329.41 378.18 432.17
256 173.65 262.98 351.06 440.75 542.37 641.63 739.60 846.40
512 337.98 516.24 691.10 869.02 ERROR ERROR ERROR ERROR
Wall-time (us), pre-session baseline (forced by monkey-patching
_WIDE_VEC_WORK_UNITS_THRESHOLD to 10**9):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 23.71 30.02 37.84 46.82 54.50 61.95 69.89
32 26.37 42.56 54.75 69.01 85.41 99.74 114.13 128.50
64 47.76 79.18 101.38 129.02 159.44 187.44 216.08 244.03
128 90.38 149.34 193.09 247.38 305.87 361.33 417.20 473.81
256 173.65 289.12 375.31 481.41 596.72 705.98 816.64 930.16
512 337.98 568.32 741.25 952.97 ERROR ERROR ERROR ERROR
Speedup (baseline / dispatcher). <=1.00x means dispatcher keeps baseline.
B=4 T=2 "1.00x" reflects the *new* heuristic (ilp=4) already in effect for
both columns; the pre-heuristic-fix baseline at that cell was 11.20 us
(1.16x slower than the post-fix 9.63 us).
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
2 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
4 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
8 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
16 1.00x 1.12x 1.10x 1.12x 1.14x 1.14x 1.12x 1.13x
32 1.00x 1.13x 1.10x 1.10x 1.09x 1.09x 1.10x 1.09x
64 1.00x 1.12x 1.08x 1.09x 1.09x 1.08x 1.09x 1.08x
128 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
256 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
512 1.00x 1.10x 1.07x 1.10x ERROR ERROR ERROR ERROR
DRAM SOL (%, of 8.0 TB/s B200 peak) at the dispatcher's output:
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 15.3 14.6 18.4 20.0 18.1 19.1 20.1 20.9
2 24.9 24.9 29.8 32.2 30.0 31.2 32.6 33.9
4 36.6 33.1 40.4 42.6 39.9 41.2 42.0 43.2
8 47.3 46.8 50.0 50.4 49.0 49.7 50.0 50.6
16 55.7 60.1 62.5 63.2 62.0 62.3 61.9 61.8
32 64.3 67.4 68.1 67.7 65.2 65.5 65.8 65.3
64 70.9 72.1 72.6 72.3 69.9 69.1 69.0 68.1
128 75.0 75.4 75.5 75.1 73.4 72.5 72.2 71.1
256 78.1 77.6 77.6 77.3 75.5 74.4 73.8 72.6
512 80.2 79.0 78.8 78.4 ERROR ERROR ERROR ERROR
Peak SOL: 80.2 % at B=512 T=1 (T=1 path unchanged; reflects the earlier
`7b7f1ac3` heuristic). Peak SOL on wide_vec-dispatched cells: 79.0 % at
B=512 T=2. Baseline-only peak at the same cells was ~71 % (B=512 T=2).
B=512 T>=5 hits the known cudaErrorIllegalAddress from
results/2026-04-03/benchmark_results.md (unrelated; pre-existing).
Correctness: 282 / 282 pytest configs pass on B=1..256, T=2..8, HV in {32, 64},
cache_intermediate_states in {True, False} via the official
`_test_gdn_decode_bf16_state_mtp_kernel` helper, reused across both
`test_gdn_decode_bf16_state_mtp_kernel` (baseline, unchanged) and the new
`test_gdn_decode_bf16_state_wide_vec_mtp_kernel` (monkey-patched symbol).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After rebasing the wide_vec dispatcher onto bf16_baseline_fix, the original
T=1 / T=2 small-batch heuristic from 192fa39d / 7b7f1ac3 regressed because
the _get_bf16_mtp_config helper was dropped in the conflict resolution
(upstream only had _select_tile_v_for_mtp for ILP=8 tile sizing). This
commit re-adds the ILP=4 variant so the PR ships without the B=4 T=2
regression that would otherwise show in CI.
Changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py:
1. New @cute.kernel gdn_decode_bf16state_mtp_ilp4_kernel. Same math as the
ILP=8 kernel, but processes 4 V-rows/group/iter instead of 8.
- ILP=4: ~48 regs/thread -> ~62 % occupancy (vs ILP=8's ~37 %)
- Minimal signature: single h0_indices (read == write). Split-pool writes
are delegated to ILP=8 by the dispatcher; the ILP=4 path never needs
h0_out_indices, so we skip the extra plumbing.
2. New @cute.jit run_gdn_decode_bf16state_mtp_ilp4 launcher, mirrors the
existing ILP=8 launcher's tile_v/grid/SMEM computation but calls the
new kernel.
3. New _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim) helper:
- work_units <= 128: (min(16, v_dim), 4) # B<=2 at HV=64
- seq_len == 2 and work_units <= 256: (tile_v, 4) # covers B=4 HV=64
- else: (tile_v, 8)
The T=2 branch compensates for the ILP=8 pipeline stall from the inline
g/beta recompute path (sGB is only populated for T > 2 in the ILP=8
kernel; at small work_units ILP=8's 37 % occupancy cannot hide the
softplus+exp+log+sigmoid latency). A sketch of the picker and dispatch follows this list.
4. gated_delta_rule_mtp dispatcher:
- When output_state_indices is None: pick (tile_v, ilp_rows) via
_get_bf16_mtp_config and route to the matching launcher.
- When output_state_indices is not None: force ilp_rows=8 so the
split-pool write path (h0_out_indices) stays on the ILP=8 kernel
that supports it.
- Cache key now includes ilp_rows so the two launchers don't collide.
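Items 3 and 4 reduce to roughly the following (a sketch; the exact signature of _select_tile_v_for_mtp is assumed):

def _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim):
    work_units = batch_size * num_v_heads
    if work_units <= 128:
        return min(16, v_dim), 4                 # e.g. B<=2 at HV=64
    if seq_len == 2 and work_units <= 256:
        return _select_tile_v_for_mtp(v_dim), 4  # B=4 HV=64: hide g/beta recompute
    return _select_tile_v_for_mtp(v_dim), 8

def _pick_mtp_path(output_state_indices, batch_size, seq_len, num_v_heads, v_dim):
    if output_state_indices is not None:
        # Split-pool writes need h0_out_indices, which only the ILP=8 kernel plumbs.
        return _select_tile_v_for_mtp(v_dim), 8
    return _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim)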
Perf verification (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI,
5 warmup + 50 bench iters; cache_intermediate=True, disable_state_update=True
for T>=2; T=1 uses state-update ON):
B   T   post-rebase before (us)   post ILP=4 re-add (us)
1 1 3.46 3.46
1 2 5.43 5.41
2 1 4.26 4.26
2 2 6.51 6.50
4 1 5.82 5.82
4 2 11.36 <- +17 % regr. 9.68 <- back to reference
8 1 9.25 9.25
8 2 13.95 13.95
16 1 15.06 15.06
16 2 21.50 (wide_vec) 21.50 (wide_vec)
Correctness: 30 / 30 test_gdn_decode_bf16_state_mtp_kernel configs pass at
B in {1, 2, 4, 8, 16}, T in {2, 4, 8}, cache_intermediate in {T, F} via
the official _test_gdn_decode_bf16_state_mtp_kernel helper. No test changes
needed — the existing test exercises the dispatcher and both ILP paths.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Follow-up to PR flashinfer-ai#3127. The shipped wide_vec kernel has a fixed tile_v=128 so its grid is only B*HV CTAs — at B <= 8 (HV=64) that's <= 512 CTAs on a 148-SM B200 (< 3.5 waves), and the dispatcher falls back to the baseline ILP=4/8 path (LDG.E.64). This commit parameterizes tile_v so the grid gains a V-tile dimension, letting the wide_vec kernel (LDG.E.128 on H) reach down into small-batch sizes that were previously starved of SM parallelism.
Changes:
1. gdn_decode_bf16_state_wide_vec.py — tile_v is now a per-call constexpr:
- Kernel decodes (i_n, i_hv, i_v) from block_idx; v_base offset by `i_v * tile_v`. ROWS_PER_GROUP and ITERS_PER_GROUP promoted from module-level constants to per-kernel constexprs.
- Launcher computes grid = B * HV * (V / tile_v) and passes tile_v + num_v_tiles to the kernel.
- Public wrapper `gated_delta_rule_mtp_wide_vec` accepts `tile_v: int = 128` (keyword-only default; positional callers unchanged).
- Valid tile_v values: {32, 64, 128}. Asserted at the launcher entry.
- Subgroup layout unchanged (16 threads x 8 BF16 per thread = LDG.E.128), so every tile_v variant uses the same per-thread memory-op widths.
- Bit-exact with the previous fixed-tile_v=128 kernel at tile_v=128 (verified: max_abs_diff = 0.0 over 252 pytest configs).
2. gdn_decode_bf16_state.py — new `_select_wide_vec_tile_v(B, HV, V=128)` helper picks tile_v by work_units = B * HV:
work_units >= 1024 -> tile_v = 128 (unchanged: B >= 16 at HV=64)
work_units >= 512  -> tile_v = 64  (new: B = 8 at HV=64)
work_units >= 128  -> tile_v = 32  (new: B = 2..4 at HV=64)
below              -> None (baseline ILP=4/8 as before)
`_WIDE_VEC_WORK_UNITS_THRESHOLD` lowered 1024 -> 128 (module-scope constant kept for external benchmarks that monkey-patch it; actual dispatch uses the new picker). Dispatcher in `gated_delta_rule_mtp` now calls the picker and passes tile_v into `gated_delta_rule_mtp_wide_vec`.
3. tests/gdn/test_decode_delta_rule.py — `test_gdn_decode_bf16_state_wide_vec_mtp_kernel` parametrized over `tile_v in {32, 64, 128}`. num_v_heads restricted to {64} (Qwen3.5 production shape; HV=32 exploratory coverage is not needed for the small-batch path). Total configs: 9 batch x 7 T x 2 cache x 3 tile_v = 378.
Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI 5 warmup + 50 bench iters, cache_intermediate=True, disable_state_update=True for T>=2; T=1 unchanged):
B   T   post-3127 before (us)   post tile_v picker (us)   speedup
1   2   5.44                    5.46                      1.00x (picker -> tv=32 ties baseline)
2   2   6.40                    6.34                      1.01x (picker -> tv=32)
4   2   9.63                    8.14                      1.18x (picker -> tv=32)
8   2   13.63                   12.74                     1.07x (picker -> tv=64)
16  2   21.20                   21.18                     1.00x (tv=128, unchanged)
32  2   37.81                   37.89                     1.00x (tv=128, unchanged)
64  2   70.69                   70.69                     1.00x (tv=128, unchanged)
Biggest headline: B=4 T=2 goes from 9.63 us to 8.14 us (1.18x) by routing to wide_vec tile_v=32, which was previously dispatched to the baseline ILP=4 kernel. Cumulative against origin/main at B=4 T=2: 11.14 us -> 8.14 us = 1.37x.
Correctness: 378 / 378 new pytest configs pass (test_gdn_decode_bf16_state_wide_vec_mtp_kernel). Equivalent 1008-config sanity run across tile_v = {32, 64, 128} x HV = {32, 64} also passed during development (not in CI to keep the default matrix size reasonable).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
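The _select_wide_vec_tile_v picker described in the commit above, in sketch form (thresholds as listed; the later cleanup drops the redundant V parameter, omitted here):

def _select_wide_vec_tile_v(batch_size, num_v_heads):
    """Pick the wide_vec V-tile size from the total work-unit count."""
    work_units = batch_size * num_v_heads
    if work_units >= 1024:
        return 128   # B >= 16 at HV=64: unchanged fast path
    if work_units >= 512:
        return 64    # B = 8 at HV=64: new
    if work_units >= 128:
        return 32    # B = 2..4 at HV=64: new
    return None      # fall back to the baseline ILP=4/8 kernels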
…l mode)
The T=1 ILP kernel (`gdn_decode_bf16state_ilp_kernel`) uses 4 warps × 32
threads with vec=4 BF16 per thread (LDG.E.64 / STG.E.64 on H state). NCU
captures at B=16/32 T=1 HV=64 show LSU pipe as the dominant bottleneck
(LSU 38.8–48.4 % vs DRAM 29.0–39.4 %), with L1/TEX at 67–75 %. The shipped
wide_vec kernel uses 8 groups × 16 threads / vec=8 BF16 (LDG.E.128 /
STG.E.128), which halves LSU instruction count and reduces L1 wavefronts.
wide_vec's SMEM-precompute phase runs `ceil(T / NUM_WARPS)` passes — at
T=1 that is 1 pass (only warp 0 does real work, others idle), so it
degenerates gracefully. Correctness: bit-exact at bf16 noise floor across
all 20 T=1 test configs (HV in {32, 64}, B in {1..512}, FP32 reference).
Changes:
1. gdn_decode_bf16_state.py::gated_delta_rule — new T=1 dispatch branch
placed BEFORE the small-batch (`B < ILP_BATCH_THRESHOLD`) MTP redirect
so B=8 HV=64 (work_units = 512) can also reach wide_vec at T=1 and get
the LDG.E.128 win (measured 1.05x at B=8). Gated by ALL of:
- K = V = 128 (wide_vec subgroup layout assumes this).
- Pool mode: `initial_state_indices is not None`. Non-pool direct-state
callers stay on the baseline ILP kernel, so the split-pool test path
(baseline) remains bit-exact with the non-pool reference path that
`test_output_state_indices` relies on.
- Single-pool write: `output_state_indices is None` or identical to
`initial_state_indices` (wide_vec has one indices tensor for R/W).
- tile_v >= 64. At T=1 the wide_vec Phase 0 SMEM-precompute overhead is
fixed per CTA while the main loop shrinks with tile_v. tile_v=32
gives only 1 ILP iter per subgroup, insufficient to amortize
Phase 0. Probe at HV=64 T=1 pool mode: tile_v=32 regresses at B=4
(0.91x); tile_v=64 wins at B=8 (1.05x); tile_v=128 wins at B>=16
(1.04-1.06x). So tile_v=32 is masked out at T=1 only (gate sketched after this list).
2. tests/gdn/test_decode_delta_rule.py — `test_gdn_decode_bf16_state_t1_kernel`
parametrization extended to also cover num_v_heads=64 (production
Qwen3.5 shape). Total T=1 CI configs: 10 batch x 2 HV = 20.
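A condensed sketch of the gate from item 1 (argument names follow the wrapper; the tensor-identity test mirrors the description above):

def _t1_wide_vec_eligible(k_dim, v_dim, initial_state_indices,
                          output_state_indices, tile_v):
    return (
        k_dim == 128 and v_dim == 128                  # subgroup layout assumption
        and initial_state_indices is not None          # pool mode only
        and (output_state_indices is None              # single-pool write:
             or output_state_indices is initial_state_indices)  # read == write
        and tile_v is not None and tile_v >= 64        # tile_v=32 can't amortize Phase 0
    )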
Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, state-update ON, no cache,
pool mode with `initial_state_indices = arange(B)`, CUPTI 5 warmup + 50
bench iters). Baseline column monkey-patches `_select_wide_vec_tile_v`
to None so the ILP/MTP kernel path runs:
B baseline wide_vec T=1 speedup
1 3.49us 3.52us 0.99x (MTP ILP=4; wide_vec gate n/a)
2 4.29us 4.26us 1.01x (MTP ILP=4; wide_vec gate n/a)
4 5.76us 5.76us 1.00x (MTP ILP=8; tile_v=32 masked)
8 9.30us 8.83us 1.05x (NEW: wide_vec tile_v=64)
16 15.01us 14.35us 1.05x (wide_vec tile_v=128)
32 26.53us 25.30us 1.05x
64 48.45us 46.53us 1.04x
128 89.70us 86.00us 1.04x
256 172.22us 165.02us 1.04x
512 337.15us 323.10us 1.04x
Peak DRAM SOL at T=1 climbs from 80.6 % (previous ceiling) to 83.9 % at
B=512. All B >= 8 in pool mode gain 1.04-1.05x. Gains are smaller than
the T>=2 win (~1.10-1.50x) because the T=1 ILP kernel was already closer
to the DRAM roofline — LSU relief has less headroom to translate into
speedup.
Scope notes:
- Non-pool (direct-state) T=1 callers: stay on baseline ILP. Routing
them through wide_vec introduces 1-BF16-ULP output differences vs the
baseline-served split-pool path, which fails test_output_state_indices
at atol=1e-3. Since the test's semantic invariant is that split-pool
and non-pool paths produce bit-identical output, both must use the
same kernel; baseline ILP is the existing common denominator.
- Split-pool writes (distinct output_state_indices tensor): stay on
baseline ILP. wide_vec does not yet plumb h0_out_indices.
- B*HV < 512: falls through to the (relocated) small-batch MTP redirect
and baseline ILP=4/8 kernels. Wide_vec at tile_v=32 regresses here
because Phase 0 overhead isn't amortized.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Drops gdn_decode_bf16state_cooprow_kernel (~270 lines) and its launcher run_gdn_decode_bf16state_cooprow (~95 lines) from flashinfer/gdn_kernels/gdn_decode_bf16_state.py. The cooprow path is unreachable: gated_delta_rule routes to wide_vec / ILP / MTP-T1 only, and the dispatch comment explicitly documents that cooprow had known correctness issues at small batch.
Backward-compatible name aliases gated_delta_rule_bf16state_cooprow and gated_delta_rule_bf16state_cooprow_mtp are kept (they alias to the live gated_delta_rule and gated_delta_rule_mtp entry points, not to the removed kernel) so external callers using those names continue to work.
Net delete: 376 lines.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The previous commit removed the cooprow kernel, which was the only consumer of `cutlass.cute.nvgpu.cpasync`. Drop the now-unused import to keep ruff happy.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The wide_vec module imported _reference_gdn_mtp from gdn_decode_bf16_state_tma_horiz, but (a) the import was already marked noqa: F401 (unused), and (b) gdn_decode_bf16_state_tma_horiz is a working-tree-only scratch module that was never committed. With it absent (e.g. on a fresh checkout) the wide_vec module fails to import, which cascades into pytest collection failure for the GDN decode tests. Drop the unused import.
No behavior change; tests now load the wide_vec module successfully.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The module's top-of-file docstring still described the cooprow path
(``Each warp processes ONE V-row at a time``, ``cp.async pipeline with
TILE_V=8 x TILE_K=128 tiles``) even though that kernel is gone. Replace
it with a current-state description that points at the live entry points
and the wide_vec / MTP fallback dispatch. Drop the module-level
constants that only the removed cooprow kernel used:
- TILE_V = 8, TILE_K = 128, NUM_STAGES = 2, NUM_THREADS = 128,
NUM_BLOCKS_PER_STATE = 8 (cp.async pipeline tile sizes)
- MTP_TILE_K = 128 (never referenced; MTP kernels use MTP_NUM_THREADS /
MTP_VEC_SIZE / MTP_ILP_ROWS only)
- ``_compiled_kernels: dict = {}`` (cooprow's compile cache; the live
caches are ``_compiled_kernels_ilp`` and ``_compiled_kernels_mtp``)
No behavior change.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Both public entry points (``gated_delta_rule`` and ``gated_delta_rule_mtp``) already assert ``K == 128 and V == 128`` at the top of the function, so ``_select_wide_vec_tile_v`` cannot be reached with V != 128. Drop the redundant ``if V != 128: return None`` and the ``V`` parameter; the remaining ``K == 128`` checks at the call sites also become redundant and are inlined out. No behavior change.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The BF16 GDN decode kernels are now strictly pool-mode: each batch element reads/writes a slot in a shared ``[pool_size, HV, V, K]`` state pool, indexed by ``initial_state_indices``. Production serving frameworks (sglang, vllm) already use pool mode exclusively; non-pool mode was only exercised by tests.
Changes:
- ``gated_delta_rule()`` and ``gated_delta_rule_mtp()`` in ``flashinfer/gdn_kernels/gdn_decode_bf16_state.py`` hard-assert ``initial_state_indices is not None``. The dispatchers no longer branch on ``use_pool``, no longer compute ``pool_size = B`` for the non-pool case, and no longer fall back to a synthesized ``cache["default_indices"]`` at runtime. The dummy ``arange(B)`` kept inside the cached kernel state remains, but only as a cute.compile() shape template.
- ``gated_delta_rule_decode_pretranspose()`` in ``flashinfer/gdn_decode.py`` keeps its public ``use_pool`` / ``state`` API unchanged. When a caller hits the BF16 branch with non-pool semantics (passing ``state`` instead of ``initial_state``), the wrapper internally treats ``state`` as a pool of size B and synthesizes sequential indices ``arange(B)`` before calling the BF16 kernel. The math is identical (``pool_size = B``, sequential access), so no behavior change for non-pool callers (tests, etc.).
The FP32 ``gdn_decode_pretranspose.py`` kernel is unaffected — it continues to support both pool and non-pool natively.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Removes ``gdn_decode_bf16state_ilp_kernel`` (the standalone T=1 ILP=8 baseline) and its launcher / cache / dispatch helpers. After non-pool support was dropped, the T=1 ILP kernel was only reachable in a narrow edge case (HV<32 with B>=ILP_BATCH_THRESHOLD=16, where wide_vec gates out at tile_v<64) — i.e. never fires on Qwen3.5 (HV=64) or any other production GDN shape. The MTP T=1 path (mtp_kernel ILP=8 / mtp_ilp4 ILP=4) covers the same shapes correctly.
Removed:
- ``gdn_decode_bf16state_ilp_kernel`` @cute.kernel and surrounding section header (~740 LOC)
- ``run_gdn_decode_bf16state_ilp`` @cute.jit launcher (~75 LOC)
- ``_compiled_kernels_ilp`` cache, ``ILP_BATCH_THRESHOLD`` constant
- ``_select_tile_v_for_batch`` helper (only used by ILP)
- ``TILE_V_ILP``, ``TILE_K_ILP``, ``NUM_THREADS_ILP``, ``VEC_SIZE_ILP``, ``ILP_ROWS`` module-level constants
Dispatcher change in ``gated_delta_rule()``: when wide_vec doesn't fire (split-pool, or B*HV too small at T=1), unconditionally redirects to ``gated_delta_rule_mtp`` instead of branching on ``B < ILP_BATCH_THRESHOLD``. The MTP path picks ILP=4 vs ILP=8 via ``_get_bf16_mtp_config`` based on work_units.
Net: ``gdn_decode_bf16_state.py`` shrinks from 3046 to 2295 lines (~1011 deletions, 40 insertions; mostly the ILP kernel body).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…rnels
After the BF16 GDN kernel layer became pool-only, three direct-call kernel tests still passed an ``[B, HV, V, K]`` state tensor without ``initial_state_indices`` and tripped the new pool-mode assertion. Fix by treating the state tensor as a pool of size B and passing sequential indices ``arange(B)``. The math is identical to the previous non-pool semantics (pool_size = B, sequential access), so coverage of the underlying kernel behavior is preserved.
Updated tests:
- ``_test_gdn_decode_bf16_state_kernel`` (used by ``test_gdn_decode_bf16_state_kernel``)
- ``_test_gdn_decode_bf16_state_t1_kernel`` (used by ``test_gdn_decode_bf16_state_t1_kernel``)
- ``test_pretranspose_api_uses_gdn_decode_bf16_state``
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
The BF16 GDN kernels are pool-only post-cleanup. The bench harness was
calling the kernel without indices, which (a) errored at T=1 with
--update-state once the pool-only assertion was added, and (b) had no
way to drive the split-pool dispatch at all.
Adds:
- ``--pool-mode {single,split}`` argparse flag (default: single)
- ``bench_gdn_decode_bf16_state`` accepts ``pool_mode``: allocates a
``[2*B, HV, V, K]`` pool under split, synthesizes
``output_state_indices = arange(B, 2*B)`` so write slots are distinct
from read slots; under single it allocates ``[B, HV, V, K]`` and
passes ``output_state_indices=None`` (read==write); sketched after this list.
- ``gdn_decode_bf16_state_wrapper`` plumbs ``output_state_indices``
through to both the T=1 and T>1 entry points; also passes
``initial_state_indices`` at T=1 (was previously omitted, which is
what caused the ``--update-state`` regression).
- ``intermediate_states_buffer`` stays sized to ``[B, T, HV, V, K]``
in both modes — the kernel keys it by READ index ``cache_idx`` and
read indices are ``arange(B)`` regardless of pool mode.
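A sketch of the allocation logic described in the second bullet (shape layout from the list above; function and variable names are illustrative):

import torch

def make_bench_state(pool_mode, B, HV, V, K, device="cuda"):
    if pool_mode == "split":
        # Write slots [B, 2B) are distinct from read slots [0, B).
        pool = torch.zeros(2 * B, HV, V, K, dtype=torch.bfloat16, device=device)
        read_idx = torch.arange(B, dtype=torch.int32, device=device)
        write_idx = torch.arange(B, 2 * B, dtype=torch.int32, device=device)
    else:  # "single": read and write the same slots
        pool = torch.zeros(B, HV, V, K, dtype=torch.bfloat16, device=device)
        read_idx = torch.arange(B, dtype=torch.int32, device=device)
        write_idx = None  # kernel treats None as read == write
    return pool, read_idx, write_idx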
Fixes the post-cleanup ``--update-state`` regression and unlocks
split-pool benchmarking. No kernel changes; baseline numbers should
match prior runs for single-pool, and split-pool exercises the
existing ``gdn_decode_bf16state_mtp_kernel`` (ILP=8) fallback path.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Wide_vec was previously single-pool only — at the dispatcher we routed split-pool callers (output_state_indices != initial_state_indices) to the slower mtp_kernel ILP=8 fallback. Per the apr27 baseline run on B200 HV=64 K=V=128, that costs 5.9-14.9 % across the marquee shapes:
Shape       single (wide_vec)   split (mtp_kernel)   overhead
T=1 B=8     8.67 us             9.18 us              +5.9%
T=1 B=32    25.22 us            28.99 us             +14.9%
T=2 B=8     12.72 us            13.76 us             +8.2%
T=2 B=32    38.06 us            42.58 us             +11.9%
T=4 B=8     18.67 us            21.04 us             +12.7%
T=4 B=32    62.29 us            69.20 us             +11.1%
Adds:
- ``h0_out_indices: cute.Tensor`` parameter to ``gdn_wide_vec_kernel`` (and ``_run_wide_vec`` launcher).
- ``write_cache_idx`` derived from ``h0_out_indices[i_n]`` with the same negative-index null-buffer redirect as the read side.
- Pre-computed write-side tiles ``ht_w0..ht_w3`` at ``flat_write_state_idx = write_cache_idx * HV + i_hv``. The final state writeback now stores via these write tiles. ``cute.local_tile`` is metadata-only so constructing both views every iteration costs nothing when single-pool callers reuse the same indices.
- Intermediate-state cache continues to key by READ index (cache_idx); it represents per-request cached state and is independent of where the final committed state lives.
- Python entry ``gated_delta_rule_mtp_wide_vec`` now accepts ``output_state_indices`` and forwards it to the kernel. When None, defaults to ``initial_state_indices`` so single-pool callers don't need to change.
- Dispatcher in ``gated_delta_rule`` and ``gated_delta_rule_mtp`` forwards ``output_state_indices`` and drops the ``output_state_indices is None`` gate that previously routed split callers to the ILP=8 fallback.
Verified on B200 with a B=16 T=4 HV=64 micro-test: split-pool produces bit-identical output to single-pool, the read slots [0..B) are preserved, the write slots [B..2B) hold the same final state, and single-pool still updates the read slots in place.
The legacy ``gdn_decode_bf16state_mtp_kernel`` is still reachable when wide_vec gates out (work_units<128 at T>=2 with split-pool) — the next commit extends ``gdn_decode_bf16state_mtp_ilp4_kernel`` to cover that case so mtp_kernel can be removed.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…rnel
Previously the ILP=4 kernel had a minimal signature (single ``h0_indices`` for both read and write), and the dispatcher in ``gated_delta_rule_mtp`` forced ILP=8 whenever ``output_state_indices`` was non-None. This left the work_units<128 split-pool case (B=1 at HV=64) on the slower ILP=8 path even though ILP=4 has higher occupancy at that shape.
Mirrors the wide_vec split-pool change:
- Adds ``h0_out_indices`` parameter to ``gdn_decode_bf16state_mtp_ilp4_kernel`` and the ``run_gdn_decode_bf16state_mtp_ilp4`` launcher.
- Computes ``write_cache_idx`` with the same ``< 0 -> slot 0`` redirect.
- Builds write-side tile views ``hta_w..htd_w`` at ``flat_write_state_idx`` and uses them in the final-state writeback.
- Dispatcher in ``gated_delta_rule_mtp`` drops the ``if output_state_indices is None`` branch that forced ILP=8 — the config picker is now pool-mode-agnostic.
- Updates the cache build / runtime call to plumb ``h0_out_indices`` / ``output_state_indices``.
The intermediate-state cache continues to key by READ index (cache_idx), matching wide_vec's semantics (per-request cache, independent of write destination).
Verified on B200 with a B=1 T=4 HV=64 micro-test (the smallest work_units shape, exercises the ilp4 path): split-pool produces bit-identical output to single-pool, read slot preserved, write slot matches single-pool result, single-pool still in-place updates.
After this commit, ``gdn_decode_bf16state_mtp_kernel`` (ILP=8) is no longer reachable from any shape — every split-pool / single-pool combination at every (B, T) routes to wide_vec or mtp_ilp4. The next commit removes mtp_kernel and its launcher.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After wide_vec gained split-pool support and mtp_ilp4 likewise,
``gdn_decode_bf16state_mtp_kernel`` (ILP=8) is no longer reachable
from any dispatch path. The remaining MTP-fallback shapes (work_units
< 128, T=1 small batch with tile_v < 64) all benefit from ILP=4's
higher occupancy, so ``_get_bf16_mtp_config`` collapses to always
returning ``ilp_rows = 4``.
Removed:
- ``gdn_decode_bf16state_mtp_kernel`` @cute.kernel (~924 LOC) and its
surrounding section header.
- ``run_gdn_decode_bf16state_mtp`` @cute.jit launcher (~95 LOC) and
surrounding section header.
- The ILP=8 branch of ``_get_bf16_mtp_config`` (T=2-stall heuristic
is unnecessary now — wide_vec covers T=2 work_units >= 128 and
ILP=4 covers the rest).
- The ``MTP_ILP_ROWS = 8`` constant (only the removed kernel used it).
- The ``ilp_rows == 4 / else`` branching in the MTP cache build and
runtime call sites — both branches now go to mtp_ilp4.
Updated:
- Module top-of-file docstring: drops mtp_kernel from the dispatch
list and notes that split-pool is supported natively.
- ``_select_tile_v_for_mtp`` docstring: tile_v multiple of
``4 * MTP_ILP4_ROWS`` (= 16), not the old ``MTP_ILP_ROWS * 4``.
Verified on B200 with a 7-shape sweep covering each dispatch path
(B in {1, 4, 8, 32} x T in {1, 4}): output bit-identical between
single-pool and split-pool, split-pool read slots stay pristine.
Net: ``gdn_decode_bf16_state.py`` shrinks from 2295 to 1221 lines
(-1074 LOC, ~47% smaller). The BF16 surface is now exactly two
@cute.kernel definitions: ``gdn_wide_vec_kernel`` (the fast path)
and ``gdn_decode_bf16state_mtp_ilp4_kernel`` (the small-batch
fallback) — both pool-only, both split-pool capable.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Adds ``test_gdn_decode_bf16_state_mtp_split_pool`` (12 cases) which
calls ``gated_delta_rule_mtp`` directly with split-pool semantics
(``output_state_indices != initial_state_indices``) and verifies:
- Output is bit-equivalent (within bf16 noise) to the single-pool
dispatch on the same inputs.
- Read slots [0..B) stay pristine in split mode.
- When state-update is enabled (cache_intermediate_states=False),
the split write slots [B..2B) hold the same final state that the
single-pool dispatch wrote into the read slots.
Sweeps three batch sizes that hit each kernel:
- B=1 -> mtp_ilp4 (work_units=64 < 128, wide_vec gates out)
- B=8 -> wide_vec (work_units=512)
- B=32 -> wide_vec (work_units=2048)
Plus T in {2, 4} and cache_intermediate_states in {True, False}.
Complements the existing ``test_output_state_indices`` (which
exercises T=1 split-pool through the wrapper) by covering the T>=2
direct-MTP path.
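The three invariants condense to roughly this check (a sketch, not the test body; run_kernel stands in for a gated_delta_rule_mtp call bound to fixed q/k/v/g/beta inputs):

import torch

def check_split_pool_invariants(run_kernel, pool: torch.Tensor, B: int,
                                read_idx, write_idx) -> None:
    """run_kernel(pool, read_idx, write_idx) -> output; pool has 2B slots."""
    ref_pool = pool.clone()
    out_single = run_kernel(ref_pool, read_idx, read_idx)    # read == write
    split_pool = pool.clone()
    out_split = run_kernel(split_pool, read_idx, write_idx)  # write to [B, 2B)
    torch.testing.assert_close(out_split, out_single)        # same output
    torch.testing.assert_close(split_pool[:B], pool[:B])     # read slots pristine
    torch.testing.assert_close(split_pool[B:], ref_pool[:B]) # final states match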
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…BF16 file
Now that the BF16 surface is exactly two ``@cute.kernel``s
(``gdn_wide_vec_kernel`` and ``gdn_decode_bf16state_mtp_ilp4_kernel``),
keeping them in two separate Python files added more friction than
isolation. Both call into the same dispatcher, share the FMA helpers,
and need the same ``NUM_SMS`` / ``_USE_PACKED_FMA`` runtime constants.
Merges the wide_vec file into ``flashinfer/gdn_kernels/gdn_decode_bf16_state.py``:
- Wide_vec layout constants (``LANES_PER_ROW``, ``ELEMS_PER_LANE``,
``NUM_WARPS``, ``NUM_THREADS``, ``NUM_GROUPS``, ``ILP_ROWS``) inlined
next to the MTP constants. Drops the unused ``TILE_K = 128``
documentation constant from the wide_vec block.
- ``gdn_wide_vec_kernel`` and its ``_run_wide_vec`` launcher land in
their own section in the merged file. Section ordering is now:
wide_vec kernel + launcher, then mtp_ilp4 kernel + launcher.
- ``_compiled_kernels_wide_vec`` cache lives next to
``_compiled_kernels_mtp``.
- ``gated_delta_rule_mtp_wide_vec`` Python entry sits between the
helper functions and the public dispatchers ``gated_delta_rule`` /
``gated_delta_rule_mtp``. Drops the two cross-file lazy imports
(``from .gdn_decode_bf16_state_wide_vec import …``) that the
dispatchers used to do — same-module references now.
- Drops the duplicated ``fma_pair_mul`` / ``fma_pair`` definitions and
``NUM_SMS`` / ``_GPU_MAJOR`` / ``_USE_PACKED_FMA`` constants that
appeared in both files; the merged file uses the existing main-file
definitions.
Test ``test_gdn_decode_bf16_state_wide_vec_mtp_kernel`` updated to
import ``gated_delta_rule_mtp_wide_vec`` from the merged location.
The two old files (1221 + 799 = 2020 lines) become one file at 1973
lines (~47 LOC saved from dedup of imports, FMA helpers, runtime
constants, and section headers).
No behavior change. Full BF16 / wide_vec test sweep verified prior to
merge; cross-shape correctness (B in {1, 8, 32} × T in {2, 4})
re-checked on the merged file (single == split bit-equivalent at
bf16 noise floor). Pre-commit clean.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
… PR flashinfer-ai#3145)
The ``intermediate_states_buffer`` is BATCH-scoped — shape ``[B, T, HV, V, K]`` — but both the wide_vec kernel (line 1115) and the mtp_ilp4 kernel (line 621) were indexing it with ``cache_idx * T * HV + i_t * HV + i_hv`` where ``cache_idx`` is the POOL slot from ``initial_state_indices[i_n]``. When ``pool_size > B`` (every realistic serving config) and ``initial_state_indices`` points at slots ``>= B`` (e.g. middle of a 1024-slot pool while servicing a B=32 batch), ``cache_idx * T * HV`` exceeds the buffer's ``B * T * HV`` extent and the ``cute.local_tile`` write goes off the end of the cache buffer -> ``cudaErrorIllegalAddress`` or silent memory corruption. This is the same bug upstream PR flashinfer-ai#3145 fixed in the now-removed ``gdn_decode_bf16state_mtp_kernel``; both surviving BF16 kernels inherited the incorrect pattern.
Fix:
- ``gdn_decode_bf16state_mtp_ilp4_kernel``: ``flat_idx = i_n * T * HV + ...`` (was ``cache_idx * T * HV + ...``).
- ``gdn_wide_vec_kernel``: same.
- Dispatcher (both ``gated_delta_rule_mtp`` and ``gated_delta_rule_mtp_wide_vec``): assert ``intermediate_states_buffer.shape[0] == B`` and reshape using ``B`` rather than ``buffer_size``. Also updates the comment / docstring to call out batch-scoped semantics explicitly.
Adds ``test_gdn_decode_bf16_state_mtp_pool_larger_than_batch`` (12 cases) which parametrizes ``pool_size_multiplier in {1, 4}``, ``batch_size in {1, 8, 32}`` and ``seq_len in {2, 4}`` so both the ilp4 path (B=1) and the wide_vec path (B=8/32) are exercised with pool indices pointing at the upper half of a 4*B-slot pool. Verified the test catches the bug: re-introducing the ``cache_idx * T * HV`` form makes the test fail with ``cudaErrorIllegalAddress``; reverting the line makes it pass again.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
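The indexing change in the commit above, as a plain-Python sketch of the in-kernel offset computation:

def intermediate_cache_offset(i_n, i_t, i_hv, T, HV):
    # Buggy form: cache_idx is a POOL slot and can be >= B, so the offset
    # can exceed the [B, T, HV, V, K] buffer:
    #   flat_idx = cache_idx * T * HV + i_t * HV + i_hv   # out of bounds
    # Fixed form: key by the per-call batch index, which is always < B.
    return i_n * T * HV + i_t * HV + i_hv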
Several comments still described the old dispatch landscape (mtp_kernel ILP=8 path, ``B*HV >= 1024`` wide_vec gate, pool-scoped intermediate buffer) even though those have been replaced this session. No code changes.
Updated:
- ``gated_delta_rule`` docstring: no longer claims the MTP fallback picks "ILP=4 or ILP=8 via _get_bf16_mtp_config" (now always ILP=4). Tightened the split-pool note — wide_vec handles split-pool natively on either dispatch path.
- ``gated_delta_rule_mtp_wide_vec`` docstring: gate threshold is ``B*HV >= 128`` (tile_v=32) up through ``>= 1024`` (tile_v=128), not just ``>= 1024``.
- ``gated_delta_rule_mtp`` docstring: ``intermediate_states_buffer`` is ``[B, T, HV, V, K]``, batch-scoped (was incorrectly documented as ``[pool_size, T, HV, V, K]``). Calls out the OOB-fix invariant for future readers.
- Comment in ``gated_delta_rule_mtp`` dispatcher: drops "wide_vec hardcodes TILE_V=128 so we gate on K==V==128 too" — TILE_V is now configurable. Drops "mtp_kernel (ILP=8)" reference (kernel removed).
- Comment in ``gated_delta_rule`` T=1 fallback: drops "dispatches to mtp_kernel (ILP=8) or mtp_ilp4_kernel" (only mtp_ilp4 left).
- Comment in ``gated_delta_rule_mtp`` post-wide_vec branch: only the ILP=4 MTP path is reachable now.
- Kernel param comments for ``intermediate_states``: shape ``[B * T * HV, V, K]``, not ``[pool_size * T * HV, V, K]``.
Two ``ILP=8`` references remain by design (line 92: explains the mtp_ilp4 design lineage relative to the original ILP=8 kernel; line 1512: explains why ILP=4's higher occupancy beats ILP=8 at the small-batch fallback shapes wide_vec doesn't hit). Both are useful historical context, not stale claims.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…s (same_pool Constexpr)
Recovers the T=1 large-B single-pool regression vs apr22 documented in
``results/reports_apr27/regression_analysis.md`` (+4.6 % to +6.7 %
across B in {16, 32, 64, 128, 256} at T=1 with state-update on, the
production Qwen3.5 decode shape).
Root cause: split-pool support added the ability to write to a
distinct pool slot via ``h0_out_indices``. The dispatcher passes
identical read/write indices in single-pool mode, so the stores go
to the same slot they used to — but nvcc cannot prove this at compile
time (``output_state_indices is initial_state_indices`` is a runtime
identity), so the SASS materialises the extra LDG.32 +
negative-redirect compare + IMAD + 4×local_tile base-pointer ops per
V-iter even when they're guaranteed-equivalent to the read-side
arithmetic. In a DRAM-saturated kernel (T=1 B>=16 hits 74 % of peak
DRAM) this stretches the critical path by ~5-7 %.
Fix: a ``Constexpr[bool] same_pool`` kernel parameter, set by the
dispatcher to ``output_state_indices is None or output_state_indices
is initial_state_indices``. When True, the kernel aliases write-side
state to the read side at compile time:
- ``write_cache_idx = cache_idx``
- ``flat_write_state_idx = flat_state_idx``
- ``ht_w* = ht*`` (write tiles alias read tiles)
nvcc DCEs the dead branch in the same_pool=True compile path; SASS
matches apr22's pre-split-pool kernel exactly. When False
(split-pool), the kernel uses the existing distinct write-side
machinery — bit-identical behaviour.
Applied symmetrically to both BF16 kernels:
- ``gdn_wide_vec_kernel`` (the production hot path)
- ``gdn_decode_bf16state_mtp_ilp4_kernel`` (the small-batch fallback)
Public API is unchanged: callers (sglang, etc.) pass ``output_state_indices``
or omit it as before; the dispatcher computes ``same_pool`` and selects
the compile variant. JIT cache footprint roughly doubles per shape that
exercises both pool modes — for sglang's typical single-pool-only call
pattern, only the same_pool=True variant ever compiles.
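A sketch of the same_pool selection (host side follows the commit text; the kernel-side effect is shown as pseudocode comments, not the real CuTe-DSL syntax):

def _same_pool(initial_state_indices, output_state_indices) -> bool:
    # Computed on the host and baked into the JIT cache key, so each pool
    # mode gets its own compiled variant.
    return (
        output_state_indices is None
        or output_state_indices is initial_state_indices
    )

# Kernel-side effect (pseudocode): with same_pool a Constexpr[bool], the
# False branch is dead code in the single-pool variant and nvcc DCEs it.
#
#   if same_pool:                      # compile-time constant
#       write_cache_idx = cache_idx    # alias the read-side index
#       ht_w = ht                      # write tiles alias read tiles
#   else:
#       write_cache_idx = max(h0_out_indices[i_n], 0)
#       ht_w = local_tile(pool, write_cache_idx * HV + i_hv)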
Verified post-fix on B200: bench in single-pool, --warmup 5 --iters 50,
T=1 with --update-state, T>=2 with --cache-intermediate-states.
Numbers in microseconds; Δ is post-fix vs the apr22 prototype reference
at commit e7bee6ad of the same branch.
Full BS × T grid (single-pool, post-fix vs apr22):
  B   | T=1    | T=2    | T=3    | T=4    | T=5    | T=6    | T=7    | T=8
----+--------+--------+--------+--------+--------+--------+--------+--------
apr22 prototype (us):
B=1 3.52 5.42 5.86 6.69 8.83 9.79 10.58 11.39
B=4 5.79 8.13 9.79 11.39 14.53 16.10 17.89 19.55
B=8 9.28 12.67 15.90 18.91 23.55 26.42 30.59 33.95
B=16 14.30 21.39 27.20 33.36 40.93 47.79 54.53 61.20
B=32 25.30 38.13 49.94 62.08 77.62 89.94 103.22 117.01
B=64 46.53 70.85 94.34 117.62 146.05 171.41 197.98 222.24
B=128 86.00 134.91 180.27 225.54 278.67 328.54 379.52 427.78
B=256 165.02 262.70 350.88 440.14 543.20 640.32 741.31 837.17
post-fix (us):
B=1 3.39 5.28 5.70 6.53 8.82 9.63 10.40 11.30
B=4 5.98 7.87 9.62 11.26 14.48 16.22 17.76 19.44
B=8 8.54 12.48 15.49 18.75 23.17 26.27 29.74 33.33
B=16 14.34 21.55 27.15 33.87 41.38 48.67 55.50 63.14
B=32 25.15 38.08 50.03 62.40 77.63 91.39 104.83 118.78
B=64 46.50 70.88 93.92 118.18 150.19 175.70 200.34 230.40
B=128 85.90 134.83 180.61 226.81 288.00 340.53 391.90 443.36
B=256 164.98 261.76 351.49 441.97 559.82 665.23 764.11 868.05
Δ post-fix vs apr22 (%):
B=1 -3.7% -2.6% -2.7% -2.4% -0.1% -1.6% -1.7% -0.8%
B=4 +3.3% -3.2% -1.7% -1.1% -0.3% +0.7% -0.7% -0.6%
B=8 -8.0% -1.5% -2.6% -0.8% -1.6% -0.6% -2.8% -1.8%
B=16 +0.3% +0.7% -0.2% +1.5% +1.1% +1.8% +1.8% +3.2%
B=32 -0.6% -0.1% +0.2% +0.5% 0.0% +1.6% +1.6% +1.5%
B=64 -0.1% 0.0% -0.4% +0.5% +2.8% +2.5% +1.2% +3.7%
B=128 -0.1% -0.1% +0.2% +0.6% +3.3% +3.7% +3.3% +3.6%
B=256 0.0% -0.4% +0.2% +0.4% +3.1% +3.9% +3.1% +3.7%
Headline: the T=1 regression is recovered. **All T=1 cells now within
+/- 3.3 % of apr22**, and the previously-regressed T=1 B>=16 cells
(was +4.6 % .. +6.7 %) are within +/- 1.5 %.
T=2/T=3/T=4 stay clean across the full B sweep (within +/- 3.2 %).
T=5..T=8 large-B residual (+2.5 % .. +3.9 % at B>=64) is unchanged by
this fix and not from the write-side machinery: with cache=ON the
writeback block is compile-time disabled, nvcc was already DCE'ing
the write-tile views, and ``same_pool`` has nothing extra to elide.
Likely from the kernel-entry LDG.32 of ``h0_out_indices`` (always
runs, regardless of caching). Defer to a follow-up if it matters;
production hot path (T=1) is recovered.
Split-pool path verified within +/- 0.2 % of pre-fix split-pool
numbers across (B, T) in {(8, 1), (8, 4), (32, 1), (32, 4), (256, 1),
(256, 4)} — split callers still hit the explicit write-side code via
the same_pool=False compile variant.
Bit-equivalence check across 8 (B, T) shapes (B in {1, 8, 32, 256} x
T in {1, 4}): single-pool output == split-pool output to bf16 noise
floor in every case.
Full pytest sweep: 513 passed, 0 failed in 18m18s
(``test_gdn_decode_bf16_state*``, ``test_decode_kernel_pretranspose_pool*``,
``test_output_state_indices*``).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Actionable comments posted: 1
The flagged code in tests/conftest.py:

while exc is not None:
    if isinstance(exc, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(
        str(exc)
    ):
        return True
    exc = exc.__cause__ or exc.__context__
return False
🏁 Script executed:
#!/bin/bash
python - <<'PY'
def old_walk(exc):
while exc is not None:
if "out of memory" in str(exc).lower() and "cuda" in str(exc).lower():
return True
exc = exc.__cause__ or exc.__context__
return False
def new_walk(exc):
seen, stack = set(), [exc]
while stack:
cur = stack.pop()
if id(cur) in seen:
continue
seen.add(id(cur))
s = str(cur).lower()
if "cuda" in s and "out of memory" in s:
return True
if cur.__cause__ is not None:
stack.append(cur.__cause__)
if cur.__context__ is not None:
stack.append(cur.__context__)
return False
ctx = RuntimeError("CUDA out of memory while formatting tensor diff")
cause = RuntimeError("wrapper cause without oom text")
top = RuntimeError("top-level runtime")
top.__cause__ = cause
top.__context__ = ctx
print("old_walk:", old_walk(top)) # expected False (missed)
print("new_walk:", new_walk(top)) # expected True (detected)
PY
Traverse both __cause__ and __context__ branches to detect OOM in wrapped exceptions.
At line 148, the code uses exc = exc.__cause__ or exc.__context__, which follows only one link. When both are present, an OOM error in the non-selected branch will be missed, causing a test to fail instead of skip. The exception chain from torch.testing.assert_close can wrap inner OOM errors in multiple branches, so both must be examined.
Proposed fix
def _is_oom_in_chain(exc: BaseException) -> bool:
- # torch.testing.assert_close wraps the inner OOM in a RuntimeError("Comparing ...")
- # whose str() does not mention "out of memory" — only __cause__ does.
- while exc is not None:
- if isinstance(exc, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(
- str(exc)
- ):
- return True
- exc = exc.__cause__ or exc.__context__
+ # torch.testing.assert_close may wrap inner exceptions; inspect the full chain.
+ seen: set[int] = set()
+ stack: list[BaseException] = [exc]
+ while stack:
+ cur = stack.pop()
+ cur_id = id(cur)
+ if cur_id in seen:
+ continue
+ seen.add(cur_id)
+ if isinstance(cur, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(
+ str(cur)
+ ):
+ return True
+ if cur.__cause__ is not None:
+ stack.append(cur.__cause__)
+ if cur.__context__ is not None:
+ stack.append(cur.__context__)
 return False
The wide_vec MTP test materializes two full-tensor float32 upcasts of
[B, T, HV, V, K] BF16 in torch.testing.assert_close. At B=128, T=8 (or
B=256, T>=2) each upcast is ~2 GiB and torch.isclose OOMs. PyTorch
re-wraps the OutOfMemoryError as RuntimeError("Comparing\n\n..."), whose
str() lacks "CUDA"/"out of memory", so conftest's OOM-skip filter misses
it and the test reports as a failure rather than a skip.
Chunk the assertion along the batch axis (~256 MiB float32 peak per
chunk). _test_gdn_decode_bf16_state_mtp_kernel is shared with the
non-wide_vec MTP test, so both benefit.
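A sketch of the chunked comparison (chunk size illustrative; the real change sizes chunks to roughly 256 MiB of float32):

import torch

def assert_close_chunked(actual: torch.Tensor, expected: torch.Tensor,
                         batch_chunk: int = 8, **kwargs) -> None:
    # Compare [B, T, HV, V, K] tensors one batch slice at a time so the
    # float32 upcast inside torch.testing.assert_close stays bounded
    # instead of materializing two full ~2 GiB copies at once.
    for start in range(0, actual.shape[0], batch_chunk):
        stop = start + batch_chunk
        torch.testing.assert_close(actual[start:stop], expected[start:stop], **kwargs)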
Fixes failures on RTX 5090 (32 GiB):
test_gdn_decode_bf16_state_wide_vec_mtp_kernel[
bfloat16-128-16-16-64-128-{7,8}-True-32]
test_gdn_decode_bf16_state_wide_vec_mtp_kernel[
bfloat16-256-16-16-64-128-{2,3}-True-32]
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
…-offset overflow
Same root cause as the original report — the CuTe-DSL pool+indices kernels
compute the per-slot element offset (cache_idx * HV + i_hv) * stride[0]
(or pool_idx * stride[0] for fp32) using Int32, so once the product
exceeds INT32_MAX the multiplication wraps to a negative offset and the
load/store hits an unmapped global address (CUDA error: an illegal
memory access was encountered). Discovered while integrating
gated_delta_rule_decode_pretranspose into vLLM's GDN decode path for
Qwen3.5-class models.
Backends fixed (HV=32, V=K=128 unless noted):
fp32 pretranspose:
gdn_decode_kernel_{small,big}_batch_pretranspose
pool_idx * stride[0] overflows at pool_idx >= 3972
(vLLM padded slot stride 540672)
bf16 small-batch fallback (B*HV <= 128):
gdn_decode_bf16state_mtp_ilp4_kernel
(cache_idx * HV + i_hv) * V * K overflows at cache_idx >= 4096
bf16 production fast path
(B*HV >= 128 at T>=2, B*HV >= 512 at T=1):
gdn_wide_vec_kernel
(cache_idx * HV + i_hv) * V * K overflows at cache_idx >= 4096
(i_n * T * HV + i_t * HV + i_hv) * V * K overflows at B*T*HV >= 131072
(intermediate-states cache writeback, e.g. B>=256 at T=8 HV=64,
with cache_intermediate_states=True)
Fix: widen pool_idx / out_pool_idx (pretranspose) and cache_idx /
write_cache_idx (bf16) to Int64 immediately after they are read from
the indices tensors; widen the per-call batch index used for the
intermediate-states cache write (i_n) to Int64 in both bf16 kernels so
that flat_idx = i_n * T * HV + i_t * HV + i_hv inherits Int64 even when
B * T * HV crosses 131072. The downstream flat_state_idx,
flat_write_state_idx and flat_idx all stay Int64, so the offset
multiplications inside cute.local_tile / h0_source[(...)] cannot wrap.
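The wrap is easy to reproduce with numpy's fixed-width integers standing in for the kernel's Int32 (thresholds from above; numpy emits an overflow warning where the device code silently wraps):

import numpy as np

HV, V, K = 32, 128, 128
cache_idx = np.int32(4096)  # first slot past the overflow threshold (i_hv = 0)

narrow = cache_idx * np.int32(HV) * np.int32(V) * np.int32(K)
print(narrow)   # -2147483648: 2**31 wrapped negative -> unmapped address

wide = np.int64(cache_idx) * HV * V * K  # widen BEFORE multiplying, as in the fix
print(wide)     # 2147483648: correct element offset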
Note on PR scope vs. PR flashinfer-ai#3147 (Ameyn/wide vec t1, merged on main after
this branch was cut): flashinfer-ai#3147 deleted gdn_decode_bf16state_mtp_kernel and
replaced the bf16 surface with gdn_decode_bf16state_mtp_ilp4_kernel
(small-batch fallback) + gdn_wide_vec_kernel (production hot path).
Both surviving kernels reproduce the same Int32 cache_idx offset bug as
the deleted kernel, and gdn_wide_vec_kernel additionally has an Int32
flat_idx offset bug on the intermediate-states writeback that's
reachable at production batch sizes (B >= 256 with MTP T=8, HV=64).
This commit fixes all three sites at once.
Regression test (tests/gdn/test_decode_pretranspose_noncontiguous_pool.py)
covers:
* test_decode_pretranspose_pool_int64_offset[3973, 8191] — fp32
vLLM-padded pool (~8.6 / 17.7 GB) at the pool_idx threshold.
* test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196] — bf16
contiguous pool (~4.3 GB) at the cache_idx threshold; at B=1 HV=32
the dispatcher selects gdn_decode_bf16state_mtp_ilp4_kernel, which
has the identical overflow site and identical Int64 fix as
gdn_wide_vec_kernel. The wide_vec / intermediate-states fixes are
structurally identical to the ilp4 fix exercised here.
Both bf16 / fp32 tests compare the pool path against a gather + direct-state
reference (numerical correctness, not just non-crashing) and assert the
in-place state update matches. VRAM-based skip when free memory is
insufficient.
AI-assisted (Cursor / Claude).
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Summary
Replaces the legacy gdn_decode_bf16state_cooprow_kernel and the gdn_decode_bf16state_mtp_kernel (ILP=8) with a new gdn_wide_vec_kernel (LDG.E.128 / STG.E.128 fast path) plus a small-batch mtp_ilp4 fallback. Drops ~1900 LOC of dead/unused code, adds split-pool support (#2905-compatible) to both surviving BF16 kernels, and ships the OOB fix mirroring upstream PR #3145 for the BF16 kernels that survived the cleanup.
Supersedes #3118. That PR's perf delta (T=1 per-call overhead + pool+padding for the ILP kernel) is the first commit on this branch (8a6e9819).
What changes
- gdn_wide_vec_kernel — 128 threads/CTA = 8 groups × 16 threads, vec=8 BF16 → LDG.E.128 / STG.E.128, ILP=4 V-rows per thread. Configurable tile_v ∈ {32, 64, 128} so the kernel covers small/medium/large B*HV work-unit sizes uniformly.
- Pool-only state management (the production serving contract). Wrapper gated_delta_rule_decode_pretranspose auto-promotes legacy non-pool callers internally — public API unchanged.
- Split-pool support: both surviving BF16 kernels (gdn_wide_vec_kernel, gdn_decode_bf16state_mtp_ilp4_kernel) natively support output_state_indices != initial_state_indices, with bit-equivalent single-pool behavior selected at compile time via Constexpr[bool] same_pool for zero-overhead dispatch.
- OOB fix: intermediate_states is indexed by the per-call batch index i_n (not the pool-scoped cache_idx), so the buffer can be sized [B, T, HV, V, K] as production callers expect. Regression test catches the bug; pre-fix triggers cudaErrorIllegalAddress in <2 s.
Removed (~1900 LOC of dead code)
- gdn_decode_bf16state_cooprow_kernel (~280 LOC)
- gdn_decode_bf16state_ilp_kernel (~740 LOC)
- gdn_decode_bf16state_mtp_kernel (ILP=8) (~940 LOC)
End-state BF16 surface = 2 @cute.kernels in one file:
- gdn_wide_vec_kernel — production hot path
- gdn_decode_bf16state_mtp_ilp4_kernel — small-batch fallback
Both pool-only, both split-pool capable, both indexed batch-scoped.
Speedup vs previous baseline
Baseline = pre-wide_vec dispatch (the mtp_kernel ILP=8 path, captured on this same branch by monkey-patching _select_wide_vec_tile_v to return None for every shape). Same harness, same hardware, same config — so the comparison isolates the kernel-level speedup that wide_vec + the cleanup deliver.
Setup: B200, HV=64, K=V=128, BF16, qk_l2norm=ON, warmup=5, iters=50, T=1 invoked with --update-state, T≥2 invoked with --cache-intermediate-states. Kernel time in microseconds (CUPTI).
Speedup (×, baseline / post-PR)
Headline
Sustained DRAM bandwidth post-PR (TB/s, 8 TB/s peak on B200)
Post-PR peaks at 6.57 TB/s = 82 % of B200 peak DRAM (T=1 B=256 production decode shape).
Split-pool
With wide_vec now supporting split-pool natively, split-pool
matches single-pool to within ±1 % at every measured shape.
Tests
Including: 477 existing BF16/wide_vec/pool tests, 12 new split-pool MTP tests, 12 new OOB regression tests covering pool_size_multiplier ∈ {1, 4} × B ∈ {1, 8, 32} × T ∈ {2, 4}, 12 wrapper-level split-pool tests.
Files changed (4)
- flashinfer/gdn_decode.py — wrapper auto-promotes BF16 non-pool → pool
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py — wide_vec inlined; dead kernels removed; split-pool plumbing; OOB fix; same_pool DCE
- tests/gdn/test_decode_delta_rule.py — split-pool + OOB regression tests
- benchmarks/bench_gdn_decode.py — --pool-mode {single,split} flag
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
- --pool-mode option to benchmark tool for configuring state pool allocation (single or split modes).
Tests