optimize gdn decode bf16 state kernel for mtp with caching. #3127
ameynaik-hub wants to merge 4 commits into flashinfer-ai:main from
Conversation
…to ILP kernel
Two related changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py and
the T=1 dispatch in flashinfer/gdn_decode.py:
1) Per-call Python overhead fix
Move torch.arange / torch.empty / from_dlpack / get_device_capability out
of the steady-state call path. These were previously called on every
invocation of gated_delta_rule and gated_delta_rule_mtp, adding ~3.75 us
of CUPTI-visible overhead per call at small BS. All default-tensor
allocation and dlpack conversion is now done once, inside the
`cache_key not in _compiled_kernels_*` block, and cached alongside the
compiled kernel. Steady-state calls pass raw torch tensors directly to
the tvm-ffi callable. Adds module-level _USE_PACKED_FMA in place of
per-call torch.cuda.get_device_capability().
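A minimal sketch of the compile-once, call-many pattern described above (illustrative: _compile_kernel stands in for the real cute.compile / from_dlpack step, and the capability threshold in _USE_PACKED_FMA is an assumed example, not the exact FlashInfer code):

import torch

_compiled_kernels: dict = {}
# Device capability is queried once at import time instead of on every call.
_USE_PACKED_FMA = (
    torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
)

def _compile_kernel(example_q: torch.Tensor):
    # Placeholder for the one-time cute.compile(...) / from_dlpack(...) work.
    def run(q: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        return q  # the real callable would launch the compiled CUDA kernel
    return run

def gated_delta_rule(q: torch.Tensor) -> torch.Tensor:
    cache_key = (q.shape, q.dtype, q.device)
    entry = _compiled_kernels.get(cache_key)
    if entry is None:
        # First call for this shape: allocate default tensors and compile,
        # then cache everything alongside the compiled kernel.
        default_indices = torch.arange(
            q.shape[0], dtype=torch.int32, device=q.device
        )
        entry = {"compiled": _compile_kernel(q), "default_indices": default_indices}
        _compiled_kernels[cache_key] = entry
    # Steady state: raw tensors go straight to the cached callable; no
    # torch.arange / from_dlpack / get_device_capability per call.
    return entry["compiled"](q, entry["default_indices"])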
2) Pool + padding support on the ILP kernel
gdn_decode_bf16state_ilp_kernel (the T=1 fast path for B >= 16) now
accepts h0_indices and h0_out_indices, matching the MTP kernel's
signature. Negative indices redirect to pool slot 0 (null buffer);
writes go to a separate flat_write_idx so input and output pool slots
can differ. The ILP launcher and gated_delta_rule wrapper thread the
new tensors through; the T=1 dispatch in flashinfer/gdn_decode.py is
collapsed so pool+indices T=1 calls no longer detour through the
heavier MTP kernel.
Design choice: kernel always takes indices (no constexpr switch).
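An illustrative sketch of the index handling described in item 2 (the function name resolve_pool_slots and its exact shape are hypothetical; in the real kernel this arithmetic happens per CTA in the CuTe code):

def resolve_pool_slots(h0_index: int, h0_out_index: int, i_hv: int, num_v_heads: int):
    # A negative read index means "no initial state": redirect to pool slot 0,
    # which is assumed to hold a zero-filled null buffer.
    read_slot = h0_index if h0_index >= 0 else 0
    # Writes carry their own index, so the updated state can land in a
    # different pool slot than the one the initial state was read from.
    write_slot = h0_out_index if h0_out_index >= 0 else 0
    # The state pool is viewed as (pool_size * HV, V, K), so flatten slot + head.
    flat_read_idx = read_slot * num_v_heads + i_hv
    flat_write_idx = write_slot * num_v_heads + i_hv
    return flat_read_idx, flat_write_idx

# Example: a request with no cached state (index -1) reads the null slot but
# still writes its result to slot 5.
print(resolve_pool_slots(-1, 5, 3, 64))  # -> (3, 323)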
Benchmark config: Qwen3.5-397B-A17B linear attention
(num_q_heads=16, num_k_heads=16, num_v_heads=64, head_size=128, bf16, qk_l2norm ON)
GPU: NVIDIA B200
Command:
python benchmarks/bench_gdn_decode.py \
--batch-size 1 4 8 16 32 64 128 256 512 \
--num-q-heads 16 --num-k-heads 16 --num-v-heads 64 \
--head-size 128 --dtype bfloat16 --warmup 20 --iters 200
Bf16State column results (us):
BS | time
1 | 3.71
4 | 5.89
8 | 9.18
16 | 14.98
32 | 26.56
64 | 48.24
128 | 89.66
256 | 172.35
512 | 337.60
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
📝 Walkthrough
This PR extends FlashInfer's GDN decoding with pooled-state indexing support and introduces a wide-vector BF16 MTP kernel variant. Changes include modifying existing BF16 kernels to accept separate read/write pool indices, adding a new ILP=4 MTP kernel variant, introducing a wide-vector BF16 MTP kernel with intermediate-state caching, and updating dispatch logic to conditionally route between kernel backends.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~80 minutes
Code Review
This pull request introduces significant performance optimizations for the BF16 GDN MTP decode kernels. Key improvements include a dynamic configuration mechanism to select optimal tiling and ILP (4 vs 8) based on workload size, and the addition of a new wide-vector kernel variant that utilizes 128-bit loads (8 BF16 elements) to maximize throughput for large batch sizes. The dispatcher logic has been updated to handle these new paths while reducing overhead by caching compiled kernels and minimizing redundant DLPack conversions. Feedback is provided regarding parameter naming consistency in the new wide-vector kernel implementation.
    q: cute.Tensor,
    k: cute.Tensor,
    v: cute.Tensor,
    b_gate: cute.Tensor,
For consistency with other kernels in this file and the project (e.g., gdn_decode_bf16state_mtp_kernel), it would be clearer to name this parameter b instead of b_gate. The public-facing API gated_delta_rule_mtp_wide_vec already uses b.
This change should be propagated to the function body and the _run_wide_vec wrapper as well.
-    b_gate: cute.Tensor,
+    b: cute.Tensor,
…stic fix
Three coordinated changes to the BF16 GDN MTP decode path on B200:
1. New kernel `gated_delta_rule_mtp_wide_vec`
(flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py):
- 128 threads/CTA organised as 8 groups of 16 threads (baseline: 4 of 32)
- vec=8 BF16 per thread -> LDG.128 / STG.128 (baseline: vec=4 -> .64)
- ILP=4 V-rows/thread, 4-stage butterfly over 16-lane subgroup
- No TMA, no persistent CTAs; the win is purely wider LSU transactions
- Peaks at 77.6 % DRAM SOL at B=256, T=2, HV=64
2. Auto-dispatcher inside `gated_delta_rule_mtp`
(flashinfer/gdn_kernels/gdn_decode_bf16_state.py):
- New module-level `_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024`
- Routes to wide_vec when `B*HV >= 1024 and T >= 2 and K == V == 128` (sketched after this list)
- T=1 and small-batch callers keep the existing baseline path verbatim,
so the public API is source-and-ABI compatible.
3. B=4 HV=64 T=2 heuristic fix in `_get_bf16_mtp_config`:
- Baseline's T=2 code path (gdn_decode_bf16_state.py) recomputes g/beta
inline instead of reading from sGB (only T>2 pre-populates sGB). The
inline softplus+exp+log+sigmoid stalls the ILP=8 pipeline at small
work_units.
- Fix: when `seq_len == 2 and work_units <= 256`, return (tile_v, 4)
instead of (tile_v, 8). ILP=4 gives ~62 % occupancy vs ILP=8's ~37 %,
covering the recompute latency.
- Measured: B=4 HV=64 T=2: 11.20 us -> 9.63 us (1.17x).
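The routing rule referenced in item 2, as a minimal sketch (the helper name _should_use_wide_vec is hypothetical; the real dispatcher inlines this check inside gated_delta_rule_mtp):

_WIDE_VEC_WORK_UNITS_THRESHOLD = 1024

def _should_use_wide_vec(B: int, T: int, K: int, V: int, HV: int) -> bool:
    # wide_vec pays off only with enough work units to fill the GPU and with
    # head dims that match the 128-bit vectorized layout.
    return (
        B * HV >= _WIDE_VEC_WORK_UNITS_THRESHOLD
        and T >= 2
        and K == 128
        and V == 128
    )

# B=16, HV=64 gives 1024 work units -> wide_vec; B=8 stays on the baseline path.
assert _should_use_wide_vec(16, 2, 128, 128, 64)
assert not _should_use_wide_vec(8, 2, 128, 128, 64)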
Benchmark - Qwen3.5-397B-A17B Gated DeltaNet shape
(B200, HV=64, H_Q=H_K=16, K=V=128, BF16, cache_intermediate=True,
disable_state_update=True; T=1 uses state-update ON, no intermediate caching;
measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state with
CUPTI, 5 warmup + 50 bench iters per cell):
Wall-time (us), dispatcher's best-of-both output
(= baseline for B<=8, wide_vec for B>=16 at this HV; T=1 always baseline):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 21.20 27.23 33.73 41.23 47.89 55.15 62.11
32 26.37 37.81 49.98 62.93 78.46 91.20 103.73 117.58
64 47.76 70.69 93.76 117.92 146.42 172.80 197.82 225.44
128 90.38 135.17 180.46 226.99 278.78 329.41 378.18 432.17
256 173.65 262.98 351.06 440.75 542.37 641.63 739.60 846.40
512 337.98 516.24 691.10 869.02 ERROR ERROR ERROR ERROR
Wall-time (us), pre-session baseline (forced by monkey-patching
_WIDE_VEC_WORK_UNITS_THRESHOLD to 10**9):
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ------- ------- ------- ------- ------- ------- ------- -------
1 3.46 5.44 5.79 6.67 8.83 9.79 10.61 11.50
2 4.25 6.40 7.14 8.27 10.66 11.97 13.07 14.18
4 5.79 9.63 10.53 12.51 16.03 18.11 20.32 22.22
8 8.96 13.63 17.02 21.14 26.11 30.02 34.16 37.92
16 15.22 23.71 30.02 37.84 46.82 54.50 61.95 69.89
32 26.37 42.56 54.75 69.01 85.41 99.74 114.13 128.50
64 47.76 79.18 101.38 129.02 159.44 187.44 216.08 244.03
128 90.38 149.34 193.09 247.38 305.87 361.33 417.20 473.81
256 173.65 289.12 375.31 481.41 596.72 705.98 816.64 930.16
512 337.98 568.32 741.25 952.97 ERROR ERROR ERROR ERROR
Speedup (baseline / dispatcher). <=1.00x means dispatcher keeps baseline.
B=4 T=2 "1.00x" reflects the *new* heuristic (ilp=4) already in effect for
both columns; the pre-heuristic-fix baseline at that cell was 11.20 us
(1.16x slower than the post-fix 9.63 us).
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
2 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
4 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
8 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x 1.00x
16 1.00x 1.12x 1.10x 1.12x 1.14x 1.14x 1.12x 1.13x
32 1.00x 1.13x 1.10x 1.10x 1.09x 1.09x 1.10x 1.09x
64 1.00x 1.12x 1.08x 1.09x 1.09x 1.08x 1.09x 1.08x
128 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
256 1.00x 1.10x 1.07x 1.09x 1.10x 1.10x 1.10x 1.10x
512 1.00x 1.10x 1.07x 1.10x ERROR ERROR ERROR ERROR
DRAM SOL (%, of 8.0 TB/s B200 peak) at the dispatcher's output:
B \ T T=1 T=2 T=3 T=4 T=5 T=6 T=7 T=8
----- ----- ----- ----- ----- ----- ----- ----- -----
1 15.3 14.6 18.4 20.0 18.1 19.1 20.1 20.9
2 24.9 24.9 29.8 32.2 30.0 31.2 32.6 33.9
4 36.6 33.1 40.4 42.6 39.9 41.2 42.0 43.2
8 47.3 46.8 50.0 50.4 49.0 49.7 50.0 50.6
16 55.7 60.1 62.5 63.2 62.0 62.3 61.9 61.8
32 64.3 67.4 68.1 67.7 65.2 65.5 65.8 65.3
64 70.9 72.1 72.6 72.3 69.9 69.1 69.0 68.1
128 75.0 75.4 75.5 75.1 73.4 72.5 72.2 71.1
256 78.1 77.6 77.6 77.3 75.5 74.4 73.8 72.6
512 80.2 79.0 78.8 78.4 ERROR ERROR ERROR ERROR
Peak SOL: 80.2 % at B=512 T=1 (T=1 path unchanged; reflects the earlier
`7b7f1ac3` heuristic). Peak SOL on wide_vec-dispatched cells: 79.0 % at
B=512 T=2. Baseline-only peak at the same cells was ~71 % (B=512 T=2).
B=512 T>=5 hits the known cudaErrorIllegalAddress from
results/2026-04-03/benchmark_results.md (unrelated; pre-existing).
Correctness: 282 / 282 pytest configs pass on B=1..256, T=2..8, HV in {32, 64},
cache_intermediate_states in {True, False} via the official
`_test_gdn_decode_bf16_state_mtp_kernel` helper, reused across both
`test_gdn_decode_bf16_state_mtp_kernel` (baseline, unchanged) and the new
`test_gdn_decode_bf16_state_wide_vec_mtp_kernel` (monkey-patched symbol).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
After rebasing the wide_vec dispatcher onto bf16_baseline_fix, the original
T=1 / T=2 small-batch heuristic from 192fa39d / 7b7f1ac3 regressed because
the _get_bf16_mtp_config helper was dropped in the conflict resolution
(upstream only had _select_tile_v_for_mtp for ILP=8 tile sizing). This
commit re-adds the ILP=4 variant so the PR ships without the B=4 T=2
regression that would otherwise show in CI.
Changes to flashinfer/gdn_kernels/gdn_decode_bf16_state.py:
1. New @cute.kernel gdn_decode_bf16state_mtp_ilp4_kernel. Same math as the
ILP=8 kernel, but processes 4 V-rows/group/iter instead of 8.
- ILP=4: ~48 regs/thread -> ~62 % occupancy (vs ILP=8's ~37 %)
- Minimal signature: single h0_indices (read == write). Split-pool writes
are delegated to ILP=8 by the dispatcher; the ILP=4 path never needs
h0_out_indices, so we skip the extra plumbing.
2. New @cute.jit run_gdn_decode_bf16state_mtp_ilp4 launcher, mirrors the
existing ILP=8 launcher's tile_v/grid/SMEM computation but calls the
new kernel.
3. New _get_bf16_mtp_config(batch_size, seq_len, num_v_heads, v_dim) helper (sketched after this list):
- work_units <= 128: (min(16, v_dim), 4) # B<=2 at HV=64
- seq_len == 2 and work_units <= 256: (tile_v, 4) # covers B=4 HV=64
- else: (tile_v, 8)
The T=2 branch compensates for the ILP=8 pipeline stall from the inline
g/beta recompute path (sGB is only populated for T > 2 in the ILP=8
kernel; at small work_units ILP=8's 37 % occupancy cannot hide the
softplus+exp+log+sigmoid latency).
4. gated_delta_rule_mtp dispatcher:
- When output_state_indices is None: pick (tile_v, ilp_rows) via
_get_bf16_mtp_config and route to the matching launcher.
- When output_state_indices is not None: force ilp_rows=8 so the
split-pool write path (h0_out_indices) stays on the ILP=8 kernel
that supports it.
- Cache key now includes ilp_rows so the two launchers don't collide.
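The _get_bf16_mtp_config heuristic referenced above, as a standalone sketch (illustrative: the real helper sizes tile_v via the existing _select_tile_v_for_mtp, which is only stubbed here):

def _get_bf16_mtp_config(batch_size: int, seq_len: int, num_v_heads: int, v_dim: int):
    """Return (tile_v, ilp_rows) for the BF16 MTP launchers."""
    work_units = batch_size * num_v_heads
    tile_v = min(64, v_dim)  # stand-in for _select_tile_v_for_mtp
    if work_units <= 128:  # e.g. B <= 2 at HV=64
        return min(16, v_dim), 4
    if seq_len == 2 and work_units <= 256:  # e.g. B = 4 at HV=64
        return tile_v, 4  # ILP=4's higher occupancy hides the g/beta recompute
    return tile_v, 8  # enough work units: ILP=8 wins

# B=4, T=2, HV=64 -> ILP=4; B=8, T=2, HV=64 -> ILP=8 (work_units = 512 > 256).
print(_get_bf16_mtp_config(4, 2, 64, 128), _get_bf16_mtp_config(8, 2, 64, 128))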
Perf verification (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI,
5 warmup + 50 bench iters; cache_intermediate=True, disable_state_update=True
for T>=2; T=1 uses state-update ON):
            post-rebase (before)   post ILP=4 re-add
B   T       (us)                   (us)
1 1 3.46 3.46
1 2 5.43 5.41
2 1 4.26 4.26
2 2 6.51 6.50
4 1 5.82 5.82
4 2 11.36 <- +17 % regr. 9.68 <- back to reference
8 1 9.25 9.25
8 2 13.95 13.95
16 1 15.06 15.06
16 2 21.50 (wide_vec) 21.50 (wide_vec)
Correctness: 30 / 30 test_gdn_decode_bf16_state_mtp_kernel configs pass at
B in {1, 2, 4, 8, 16}, T in {2, 4, 8}, cache_intermediate in {T, F} via
the official _test_gdn_decode_bf16_state_mtp_kernel helper. No test changes
needed — the existing test exercises the dispatcher and both ILP paths.
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
192fa39 to 16f6f14 (force-pushed)
Actionable comments posted: 8
📒 Files selected for processing (4)
- flashinfer/gdn_decode.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py
- tests/gdn/test_decode_delta_rule.py
beta_x_pre = softplus_beta * x_pre
exp_beta_x_pre = cute.exp(beta_x_pre, fastmath=True)
softplus_val_pre = (cutlass.Float32(1.0) / softplus_beta) * cute.log(
    cutlass.Float32(1.0) + exp_beta_x_pre, fastmath=True
)
use_softplus_pre = (
    cutlass.Float32(1.0)
    if beta_x_pre <= softplus_threshold
    else cutlass.Float32(0.0)
)
softplus_x_pre = (
    use_softplus_pre * softplus_val_pre
    + (cutlass.Float32(1.0) - use_softplus_pre) * x_pre
)
Avoid evaluating the overflowing softplus branch.
This computes exp(beta_x_pre) before applying the threshold. For large positive x_pre, exp_beta_x_pre can become inf, and the arithmetic mask can produce 0 * inf -> NaN.
Proposed fix
- exp_beta_x_pre = cute.exp(beta_x_pre, fastmath=True)
- softplus_val_pre = (cutlass.Float32(1.0) / softplus_beta) * cute.log(
- cutlass.Float32(1.0) + exp_beta_x_pre, fastmath=True
- )
- use_softplus_pre = (
- cutlass.Float32(1.0)
- if beta_x_pre <= softplus_threshold
- else cutlass.Float32(0.0)
- )
- softplus_x_pre = (
- use_softplus_pre * softplus_val_pre
- + (cutlass.Float32(1.0) - use_softplus_pre) * x_pre
- )
+ if beta_x_pre <= softplus_threshold:
+ exp_beta_x_pre = cute.exp(beta_x_pre, fastmath=True)
+ softplus_x_pre = (
+ cutlass.Float32(1.0) / softplus_beta
+ ) * cute.log(cutlass.Float32(1.0) + exp_beta_x_pre, fastmath=True)
+ else:
+ softplus_x_pre = x_pre🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py` around lines 239 -
252, The softplus branch currently computes exp(beta_x_pre) unconditionally
which can overflow; clamp beta_x_pre to at most softplus_threshold (e.g.,
beta_x_pre_clamped = min(beta_x_pre, softplus_threshold)) before calling
cute.exp/cute.log to avoid inf, then compute exp_beta_x_pre and softplus_val_pre
from beta_x_pre_clamped and keep the existing use_softplus_pre mask and final
softplus_x_pre selection (references: softplus_beta, x_pre, beta_x_pre,
softplus_threshold, exp_beta_x_pre, softplus_val_pre, use_softplus_pre,
softplus_x_pre, cute.exp, cute.log).
# Final state write-back: skip when caching (inter[T-1] already has it)
if cutlass.const_expr(
    not disable_state_update and not cache_intermediate_states
):
Preserve state-update semantics when caching is enabled.
With intermediate_states_buffer set, effective_disable_final=True skips updating initial_state_source even when disable_state_update=False. If this path is reached through the existing MTP dispatcher, callers can silently stop seeing the final state in the input pool; either copy buffer[:, T-1] back for update-enabled calls or gate wide-vec caching to verify-only mode.
Also applies to: 631-635
h0_source = initial_state_source.reshape(pool_size * HV_val, V_val, K_val)

cache_intermediate_states = intermediate_states_buffer is not None
if cache_intermediate_states:
    buffer_size = intermediate_states_buffer.shape[0]
    cache_steps = intermediate_states_buffer.shape[1]
    assert cache_steps >= T_val
    assert intermediate_states_buffer.dtype == torch.bfloat16
    intermediate_states = intermediate_states_buffer.reshape(
        buffer_size * cache_steps * HV_val, V_val, K_val
    )
    if not intermediate_states.is_contiguous():
        intermediate_states = intermediate_states.contiguous()
Reject non-contiguous intermediate_states_buffer or ensure writeback.
The kernel receives a reshaped view of intermediate_states_buffer. If that view is non-contiguous, .contiguous() allocates a copy; the kernel then updates the copy while intermediate_states_buffer remains stale. When cache_intermediate_states=True, the docstring requires the caller to read the final state from buffer[:, T-1], but non-contiguous inputs break this contract silently. Either assert contiguity before reshape, or copy results back to the original buffer after kernel launch.
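One possible shape of the contiguity guard (an illustrative sketch, not committed code; the helper name reshape_cached_states is hypothetical and the variable names follow the quoted hunk above):

import torch

def reshape_cached_states(intermediate_states_buffer, HV_val, V_val, K_val, T_val):
    # Require a contiguous buffer so the reshape stays a view and the kernel's
    # writes remain visible in the caller's buffer[:, T-1].
    if intermediate_states_buffer is None:
        return None
    assert intermediate_states_buffer.is_contiguous(), (
        "intermediate_states_buffer must be contiguous"
    )
    assert intermediate_states_buffer.dtype == torch.bfloat16
    buffer_size = intermediate_states_buffer.shape[0]
    cache_steps = intermediate_states_buffer.shape[1]
    assert cache_steps >= T_val
    return intermediate_states_buffer.reshape(
        buffer_size * cache_steps * HV_val, V_val, K_val
    )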
cache_key = (
    "v3_mtp_bf16",
    B_val,
    T_val,
    H_val,
    HV_val,
    K_val,
    V_val,
    pool_size,
    effective_disable_final,
    cache_intermediate_states,
    use_qk_l2norm_in_kernel,
    scale,
    softplus_beta,
    softplus_threshold,
    use_packed_fma,
)
if cache_key not in _compiled_kernels_wide_vec:
    default_indices = torch.arange(B_val, dtype=torch.int32, device=q.device)
    default_output = torch.empty(
        B_val, T_val, HV_val, V_val, device=q.device, dtype=q.dtype
    )

    if initial_state_indices is None:
        initial_state_indices = default_indices
    if output is None:
        output = default_output

    h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True)
    inter_ = from_dlpack(intermediate_states, assumed_align=32, enable_tvm_ffi=True)
    q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True)
    k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True)
    v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True)
    a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True)
    b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True)
    A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True)
    dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
    o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True)
    h0_idx_ = from_dlpack(
        initial_state_indices, assumed_align=32, enable_tvm_ffi=True
    )

    _compiled_kernels_wide_vec[cache_key] = {
        "compiled": cute.compile(
            _run_wide_vec,
            h_,
            inter_,
            A_log_,
            a_,
            dt_bias_,
            q_,
            k_,
            v_,
            b_,
            o_,
            h0_idx_,
            softplus_beta,
            softplus_threshold,
            scale,
            HV_val,
            B_val,
            T_val,
            H_val,
            K_val,
            V_val,
            use_qk_l2norm_in_kernel,
            effective_disable_final,
            cache_intermediate_states,
            use_packed_fma,
            stream,
            options="--enable-tvm-ffi --generate-line-info --opt-level 3",
        ),
        "default_indices": default_indices,
        "output": default_output,
    }

cache = _compiled_kernels_wide_vec[cache_key]
if initial_state_indices is None:
    initial_state_indices = cache["default_indices"]
if output is None:
    output = cache["output"]
Fix cache key to include full tensor metadata and allocate fresh outputs on every call.
The compiled kernel depends on tensor device, dtype, and strides via DLPack conversion (lines 671–677), but the cache key omits device type/index, all tensor dtypes, and strides. This causes incorrect cache hits across devices or dtype variants.
More critically, when output=None, the code caches and reuses the same default_output tensor (lines 660–665, 710–712, 718–720). A second call with matching cache key returns the same tensor object, causing the kernel to overwrite the first call's result.
Extend the cache key to include q.device.type, q.device.index, dtypes of all tensors passed to cute.compile(), initial_state_indices.dtype, and strides of all tensors. Always allocate a fresh output buffer unless the caller explicitly provides one.
Sketch of safer structure
+ if output is None:
+ output = torch.empty(
+ B_val, T_val, HV_val, V_val, device=q.device, dtype=q.dtype
+ )
+
+ indices_dtype = (
+ initial_state_indices.dtype if initial_state_indices is not None else torch.int32
+ )
cache_key = (
"v3_mtp_bf16",
+ q.device.type,
+ q.device.index,
+ q.dtype,
+ k.dtype,
+ v.dtype,
+ a.dtype,
+ b.dtype,
+ A_log.dtype,
+ dt_bias.dtype,
+ initial_state_source.dtype,
+ indices_dtype,
+ output.dtype,
+ q.stride(),
+ k.stride(),
+ v.stride(),
+ a.stride(),
+ b.stride(),
+ initial_state_source.stride(),
B_val,
T_val,
H_val,
@@
- "output": default_output,
}
@@
- if output is None:
- output = cache["output"]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state_wide_vec.py` around lines 640 -
720, The cache_key currently omits per-tensor device, dtype and stride info and
also caches a reusable default_output that gets returned across calls; update
the cache key construction (the variable cache_key used to index
_compiled_kernels_wide_vec) to include q.device.type and q.device.index plus the
dtype and .stride() (as tuples) for every tensor passed into
from_dlpack/cute.compile (h0_source, intermediate_states, q, k, v, a, b, A_log,
dt_bias, output and initial_state_indices), and stop storing/reusing a single
default_output in the cache — only cache the compiled kernel and
default_indices, and always allocate a fresh output buffer when output is None
before calling from_dlpack/cute.compile so each call gets its own output tensor.
if cutlass.const_expr(cache_intermediate_states):
    for i in cutlass.range_constexpr(vec_size):
        r_hb4_0[i] = cutlass.BFloat16(r_h[0, i])
        r_hb4_1[i] = cutlass.BFloat16(r_h[1, i])
        r_hb4_2[i] = cutlass.BFloat16(r_h[2, i])
        r_hb4_3[i] = cutlass.BFloat16(r_h[3, i])

if cutlass.const_expr(cache_intermediate_states):
    flat_idx = cache_idx * T * HV + i_t * HV + i_hv
    ita = cute.local_tile(
Use cache_steps, not T, for intermediate-state pool stride.
intermediate_states_buffer allows shape[1] >= T, but the ILP=4 kernel flattens with cache_idx * T * HV. For cache_steps > T and cache_idx > 0, this writes into the wrong pool slot. Pass cache_steps as a constexpr and use it in the flattened index.
Proposed fix shape
def gdn_decode_bf16state_mtp_ilp4_kernel(
@@
T: cutlass.Constexpr[int],
+ cache_steps: cutlass.Constexpr[int],
@@
- flat_idx = cache_idx * T * HV + i_t * HV + i_hv
+ flat_idx = cache_idx * cache_steps * HV + i_t * HV + i_hv
Apply the same cache_steps parameter through the launcher/cache key so compiled kernels are distinct when the buffer step dimension changes.
Also applies to: 3339-3348
    4 * T * (K + 8)  # sQ: T × (K+8) × 4 bytes (shared, one copy)
    + 4 * T * (K + 8)  # sK: same
    + 4 * T * 2  # sGB: T × 2 × 4 bytes (shared)
Replace Unicode multiplication signs in comments.
Ruff flags the × characters here with RUF003; use plain x to keep lint clean.
Proposed fix
- 4 * T * (K + 8) # sQ: T × (K+8) × 4 bytes (shared, one copy)
+ 4 * T * (K + 8) # sQ: T x (K+8) x 4 bytes (shared, one copy)
+ 4 * T * (K + 8) # sK: same
- + 4 * T * 2 # sGB: T × 2 × 4 bytes (shared)
+ + 4 * T * 2 # sGB: T x 2 x 4 bytes (shared)
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 2678-2678: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 2678-2678: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 2680-2680: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 2680-2680: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
use_pool = initial_state_indices is not None
if output_state_indices is not None and output_state_indices.dtype != torch.int32:
    output_state_indices = output_state_indices.to(torch.int32)
Normalize initial_state_indices to int32 too.
The kernels are compiled with int32 index tensors, but only output_state_indices is converted. Passing the usual PyTorch int64 indices for initial_state_indices can fail the compiled call or read the wrong slots.
Proposed fix
use_pool = initial_state_indices is not None
+ if initial_state_indices is not None and initial_state_indices.dtype != torch.int32:
+ initial_state_indices = initial_state_indices.to(torch.int32)
if output_state_indices is not None and output_state_indices.dtype != torch.int32:
output_state_indices = output_state_indices.to(torch.int32)
@@
+ if initial_state_indices is not None and initial_state_indices.dtype != torch.int32:
+ initial_state_indices = initial_state_indices.to(torch.int32)
if output_state_indices is not None and output_state_indices.dtype != torch.int32:
output_state_indices = output_state_indices.to(torch.int32)
Also applies to: 3331-3332
default_output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
default_indices = torch.arange(B, dtype=torch.int32, device=q.device)
Don’t return the cached compile-time output buffer.
When output is None, same-shape calls reuse cache["output"], so a later call overwrites any previously returned tensor. Keep the cached tensor only for compilation and allocate a fresh result unless the caller explicitly passes output.
Proposed fix
- "output": default_output,
+ "_compile_output": default_output,
@@
if output is None:
- output = cache["output"]
+ output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
@@
- "output": default_output,
+ "_compile_output": default_output,
@@
- "output": default_output,
+ "_compile_output": default_output,
@@
if output is None:
- output = cache["output"]
+ output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
Also applies to: 3192-3193, 3433-3435, 3531-3532
/bot run
Follow-up to PR flashinfer-ai#3127. The shipped wide_vec kernel has a fixed
tile_v=128 so its grid is only B*HV CTAs — at B <= 8 (HV=64) that's <= 512 CTAs
on a 148-SM B200 (< 3.5 waves), and the dispatcher falls back to the baseline
ILP=4/8 path (LDG.E.64). This commit parameterizes tile_v so the grid gains a
V-tile dimension, letting the wide_vec kernel (LDG.E.128 on H) reach down into
small-batch sizes that were previously starved of SM parallelism.
Changes:
1. gdn_decode_bf16_state_wide_vec.py — tile_v is now a per-call constexpr:
   - Kernel decodes (i_n, i_hv, i_v) from block_idx; v_base offset by
     `i_v * tile_v`. ROWS_PER_GROUP and ITERS_PER_GROUP promoted from
     module-level constants to per-kernel constexprs.
   - Launcher computes grid = B * HV * (V / tile_v) and passes tile_v +
     num_v_tiles to the kernel.
   - Public wrapper `gated_delta_rule_mtp_wide_vec` accepts `tile_v: int = 128`
     (keyword-only default; positional callers unchanged).
   - Valid tile_v values: {32, 64, 128}. Asserted at the launcher entry.
   - Subgroup layout unchanged (16 threads x 8 BF16 per thread = LDG.E.128),
     so every tile_v variant uses the same per-thread memory-op widths.
   - Bit-exact with the previous fixed-tile_v=128 kernel at tile_v=128
     (verified: max_abs_diff = 0.0 over 252 pytest configs).
2. gdn_decode_bf16_state.py — new `_select_wide_vec_tile_v(B, HV, V=128)`
   helper picks tile_v by work_units = B * HV (sketched below):
     work_units >= 1024 -> tile_v = 128 (unchanged: B >= 16 at HV=64)
     work_units >=  512 -> tile_v =  64 (new: B = 8 at HV=64)
     work_units >=  128 -> tile_v =  32 (new: B = 2..4 at HV=64)
     below              -> None        (baseline ILP=4/8 as before)
   `_WIDE_VEC_WORK_UNITS_THRESHOLD` lowered 1024 -> 128 (module-scope constant
   kept for external benchmarks that monkey-patch it; actual dispatch uses the
   new picker). Dispatcher in `gated_delta_rule_mtp` now calls the picker and
   passes tile_v into `gated_delta_rule_mtp_wide_vec`.
3. tests/gdn/test_decode_delta_rule.py —
   `test_gdn_decode_bf16_state_wide_vec_mtp_kernel` parametrized over
   `tile_v in {32, 64, 128}`. num_v_heads restricted to {64} (Qwen3.5
   production shape; HV=32 exploratory coverage is not needed for the
   small-batch path). Total configs: 9 batch x 7 T x 2 cache x 3 tile_v = 378.
Perf (B200, HV=64, H_Q=H_K=16, K=V=128, BF16, CUPTI, 5 warmup + 50 bench iters,
cache_intermediate=True, disable_state_update=True for T>=2; T=1 unchanged):
            post-#3127     post tile_v
B   T       before (us)    picker (us)   speedup
1   2       5.44           5.46          1.00x (picker -> tv=32 ties baseline)
2   2       6.40           6.34          1.01x (picker -> tv=32)
4   2       9.63           8.14          1.18x (picker -> tv=32)
8   2       13.63          12.74         1.07x (picker -> tv=64)
16  2       21.20          21.18         1.00x (tv=128, unchanged)
32  2       37.81          37.89         1.00x (tv=128, unchanged)
64  2       70.69          70.69         1.00x (tv=128, unchanged)
Biggest headline: B=4 T=2 goes from 9.63 us to 8.14 us (1.18x) by routing to
wide_vec tile_v=32, which was previously dispatched to the baseline ILP=4
kernel. Cumulative against origin/main at B=4 T=2: 11.14 us -> 8.14 us = 1.37x.
Correctness: 378 / 378 new pytest configs pass
(test_gdn_decode_bf16_state_wide_vec_mtp_kernel). An equivalent 1008-config
sanity run across tile_v = {32, 64, 128} x HV = {32, 64} also passed during
development (not in CI to keep the default matrix size reasonable).
AI-assisted by Claude Code.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
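For reference, a minimal sketch of the _select_wide_vec_tile_v picker described in change 2 above (the thresholds match the commit message; the function body and examples are illustrative, not the exact FlashInfer code):

from typing import Optional

def _select_wide_vec_tile_v(B: int, HV: int, V: int = 128) -> Optional[int]:
    # Pick the V-tile size from the total work units; one CTA is launched per
    # (batch, value-head, V-tile). None means "stay on the baseline ILP=4/8 path".
    work_units = B * HV
    if work_units >= 1024:
        return 128  # B >= 16 at HV=64: unchanged behaviour
    if work_units >= 512:
        return 64   # B = 8 at HV=64
    if work_units >= 128:
        return 32   # B = 2..4 at HV=64
    return None

# Example: B=4, HV=64 -> 256 work units -> tile_v=32, so the grid becomes
# B * HV * (V / tile_v) = 1024 CTAs instead of 256.
assert _select_wide_vec_tile_v(4, 64) == 32
assert _select_wide_vec_tile_v(16, 64) == 128
assert _select_wide_vec_tile_v(1, 64) is None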
📌 Description
GDN BF16-state MTP — dispatcher output (B200)
Config: HV=64, H_Q=H_K=16, K=V=128, BF16, qk_l2norm=ON. T=1 uses state-update ON (no intermediate caching); T≥2 uses intermediate caching ON with state-update OFF. Measured via benchmarks/bench_gdn_decode.py::bench_gdn_decode_bf16_state.
DRAM SOL % (of 8.0 TB/s B200 peak)
Speedup vs origin/main (8559397)
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).
Reviewer Notes