Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8a6e981
perf(gdn): fix bf16_state T=1 per-call overhead and add pool+padding …
ameynaik-hub Apr 19, 2026
1ad4dbe
perf(gdn): add wide_vec BF16 MTP kernel, auto-dispatch, and T=2 heuri…
ameynaik-hub Apr 20, 2026
bb558d5
perf(gdn): restore ILP=4 MTP launcher for small-batch BF16 MTP decode
ameynaik-hub Apr 20, 2026
316c36b
perf(gdn): extend wide_vec fast path to small batches via tile_v
ameynaik-hub Apr 21, 2026
241bdd2
perf(gdn): route T=1 decode through wide_vec at large work_units (poo…
ameynaik-hub Apr 22, 2026
6b7a731
chore(gdn): remove dead cooprow BF16 decode kernel and launcher
ameynaik-hub Apr 27, 2026
ee54669
chore(gdn): drop unused cpasync import after cooprow removal
ameynaik-hub Apr 27, 2026
898171f
fix(gdn): drop dead import of _reference_gdn_mtp from wide_vec module
ameynaik-hub Apr 27, 2026
7aa0141
chore(gdn): drop cooprow-era module constants and stale module docstring
ameynaik-hub Apr 27, 2026
936231f
chore(gdn): drop redundant V!=128 gate in _select_wide_vec_tile_v
ameynaik-hub Apr 27, 2026
088d879
chore(gdn): make BF16 GDN kernels pool-only; auto-promote at wrapper
ameynaik-hub Apr 27, 2026
0a92290
chore(gdn): remove dead T=1 ILP=8 kernel; route T=1 fallback through MTP
ameynaik-hub Apr 27, 2026
97c71f3
test(gdn): pass initial_state_indices=arange(B) for pool-only BF16 ke…
ameynaik-hub Apr 27, 2026
4da2b2a
bench(gdn): add --pool-mode {single,split} for BF16 state benchmark
ameynaik-hub Apr 27, 2026
e108118
feat(gdn): add split-pool support to gdn_wide_vec_kernel
ameynaik-hub Apr 27, 2026
0b7b66f
feat(gdn): add split-pool support to gdn_decode_bf16state_mtp_ilp4_ke…
ameynaik-hub Apr 27, 2026
3260f43
chore(gdn): remove dead gdn_decode_bf16state_mtp_kernel and ILP=8 path
ameynaik-hub Apr 27, 2026
9e475a5
test(gdn): add split-pool MTP coverage for wide_vec and mtp_ilp4
ameynaik-hub Apr 27, 2026
49d9d2a
refactor(gdn): merge gdn_decode_bf16_state_wide_vec.py into the main …
ameynaik-hub Apr 27, 2026
a234644
fix(gdn): use batch-scoped i_n for intermediate_states indexing (OOB;…
ameynaik-hub Apr 27, 2026
aa63cba
chore(gdn): refresh stale comments / docstrings post-cleanup
ameynaik-hub Apr 28, 2026
e1f6c53
perf(gdn): elide write-side address arithmetic when reads/writes alia…
ameynaik-hub Apr 28, 2026
7266b5e
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub Apr 28, 2026
4c7c6c9
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub Apr 28, 2026
f8c73dc
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub Apr 30, 2026
bab0309
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub May 1, 2026
bc6c269
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub May 1, 2026
51ccd37
test(gdn): chunk wide_vec MTP intermediate-state assert to avoid OOM
ameynaik-hub May 3, 2026
e61db27
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub May 4, 2026
1133601
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub May 4, 2026
89c2ebe
Merge branch 'main' into ameyn/wide_vec_t1
ameynaik-hub May 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions benchmarks/bench_gdn_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,7 +1836,7 @@ def gdn_decode_bf16_state_wrapper(
q: torch.Tensor, # [B, T, H_Q, K]
k: torch.Tensor, # [B, T, H_K, K]
v: torch.Tensor, # [B, T, HV, V]
state: torch.Tensor, # [B, HV, V, K] - K-last layout (pretranspose)
state: torch.Tensor, # [pool_size, HV, V, K] BF16 (K-last layout)
A_log: torch.Tensor, # [HV]
a: torch.Tensor, # [B, T, HV]
dt_bias: torch.Tensor, # [HV]
Expand All @@ -1849,11 +1849,14 @@ def gdn_decode_bf16_state_wrapper(
intermediate_states_buffer=None,
disable_state_update: bool = False,
initial_state_indices=None,
output_state_indices=None,
):
"""
Wrapper for gdn_decode_bf16_state GDN kernel.
Supports T=1 (calls gated_delta_rule) and T>1 (calls gated_delta_rule_mtp).
Adapts the interface to match the benchmark's calling convention.
Both pool-only paths require initial_state_indices to be passed by the
caller. When output_state_indices is non-None and differs from
initial_state_indices, the call exercises the split-pool dispatch.

Note: The kernel returns output directly, no copy needed.
"""
Expand All @@ -1874,6 +1877,8 @@ def gdn_decode_bf16_state_wrapper(
v=v,
b=b,
initial_state_source=state,
initial_state_indices=initial_state_indices,
output_state_indices=output_state_indices,
use_qk_l2norm_in_kernel=use_qk_l2norm,
scale=scale,
)
Expand All @@ -1890,6 +1895,7 @@ def gdn_decode_bf16_state_wrapper(
b=b,
initial_state_source=state,
initial_state_indices=initial_state_indices,
output_state_indices=output_state_indices,
intermediate_states_buffer=intermediate_states_buffer,
disable_state_update=disable_state_update,
use_qk_l2norm_in_kernel=use_qk_l2norm,
Expand Down Expand Up @@ -2246,12 +2252,24 @@ def bench_gdn_decode_bf16_state(
disable_state_update: bool = False,
warmup_iters: int = 10,
bench_iters: int = 100,
pool_mode: str = "single",
):
"""Benchmark BF16 state kernel."""
"""Benchmark BF16 state kernel.

pool_mode:
- "single": [B, HV, V, K] state, indices = arange(B), output_indices = None
(read == write). Exercises wide_vec single-pool / mtp fallbacks.
- "split": [2B, HV, V, K] state, read indices = arange(B),
write indices = arange(B, 2B). Exercises the split-pool dispatch
(speculative-decoding / MTP-verify shape).
"""
if not GDN_DECODE_BF16_STATE_AVAILABLE:
raise RuntimeError("gdn_decode_bf16_state kernel is not available")

assert seq_len >= 1, f"seq_len must be >= 1, got T={seq_len}"
assert pool_mode in ("single", "split"), (
f"BF16 state path supports pool_mode in {{single, split}}, got {pool_mode}"
)

num_o_heads = max(num_q_heads, num_v_heads)
num_sab_heads = num_o_heads
Expand All @@ -2268,17 +2286,23 @@ def bench_gdn_decode_bf16_state(
dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda")
b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda")

# Initial state: [B, HV, V, K] (K-fast layout, BF16)
# Pool sized for the indexing mode: split needs 2B slots so write slots
# are distinct from read slots.
pool_size = 2 * batch_size if pool_mode == "split" else batch_size
state = torch.randn(
batch_size,
pool_size,
num_sab_heads,
head_size,
head_size,
dtype=torch.bfloat16,
device="cuda",
)

# Intermediate states buffer (MTP only, when caching is enabled)
# Intermediate states buffer (MTP only, when caching is enabled).
# The kernel keys the cache by READ indices (cache_idx * T * HV + ...),
# and we use read_indices = arange(B) in both pool modes, so the buffer's
# first dim is always B regardless of pool_mode (the upper half of the
# pool, used only for split-mode writes, doesn't appear in the cache).
intermediate_states_buffer = None
if cache_intermediate_states and T > 1:
intermediate_states_buffer = torch.zeros(
Expand All @@ -2291,11 +2315,19 @@ def bench_gdn_decode_bf16_state(
device="cuda",
)

# Pre-allocate output and state indices (avoid per-call torch.arange overhead in CUPTI)
# Pre-allocate output and state indices (avoid per-call torch.arange
# overhead in CUPTI). For split-pool, write indices point into the upper
# half of the pool so reads and writes don't alias.
output = torch.empty(
batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda"
)
initial_state_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
if pool_mode == "split":
output_state_indices = torch.arange(
batch_size, 2 * batch_size, dtype=torch.int32, device="cuda"
)
else:
output_state_indices = None

# Scale factor
scale = 1.0 / (head_size**0.5)
Expand All @@ -2317,6 +2349,7 @@ def bench_gdn_decode_bf16_state(
intermediate_states_buffer=intermediate_states_buffer,
disable_state_update=disable_state_update,
initial_state_indices=initial_state_indices,
output_state_indices=output_state_indices,
),
enable_cupti=True,
dry_run_iters=warmup_iters,
Expand Down Expand Up @@ -2369,6 +2402,7 @@ def run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm):

cache_intermediate = getattr(args, "cache_intermediate_states", False)
disable_state_update = not getattr(args, "update_state", False)
pool_mode = getattr(args, "pool_mode", "single")

print("\n" + "=" * 100)
print(f"BF16 State GDN Benchmark (T={valid_seq_lens})")
Expand All @@ -2377,7 +2411,8 @@ def run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm):
f"v_heads={args.num_v_heads}, head_size={args.head_size}, "
f"dtype={args.dtype}, qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}, "
f"cache_intermediate={'ON' if cache_intermediate else 'OFF'}, "
f"update_state={'ON' if not disable_state_update else 'OFF'}"
f"update_state={'ON' if not disable_state_update else 'OFF'}, "
f"pool_mode={pool_mode}"
)
print("=" * 100)
print()
Expand All @@ -2401,6 +2436,7 @@ def run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm):
disable_state_update=disable_state_update,
warmup_iters=args.warmup,
bench_iters=args.iters,
pool_mode=pool_mode,
)
all_results.append(result)

Expand Down Expand Up @@ -2794,6 +2830,19 @@ def main():
action="store_true",
help="Update final state (disable_state_update=False) for MTP benchmark",
)
parser.add_argument(
"--pool-mode",
choices=("single", "split"),
default="single",
help=(
"Pool indexing mode for the BF16 state benchmark. "
"'single' (default): treat the [B, HV, V, K] state as a pool of "
"size B with sequential indices arange(B) (read==write). "
"'split': allocate a pool of size 2*B; reads from slots [0..B), "
"writes to slots [B..2B), exercising the split-pool dispatch "
"(speculative-decoding / MTP-verify shape)."
),
)
parser.add_argument(
"--warmup",
type=int,
Expand Down
22 changes: 16 additions & 6 deletions flashinfer/gdn_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,16 @@ def gated_delta_rule_decode_pretranspose(
)
assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}"
scale_val = K**-0.5 if scale is None else scale
if T == 1 and not use_pool:
# T=1 kernel does not accept initial_state_indices
# The BF16 path is pool-only. When the caller uses non-pool semantics
# (passes ``state`` instead of ``initial_state``), treat ``state`` as
# a pool of size B and synthesize sequential indices arange(B).
if use_pool:
bf16_pool = initial_state
bf16_indices = initial_state_indices
else:
bf16_pool = state
bf16_indices = torch.arange(B, dtype=torch.int32, device=q.device)
Comment on lines +267 to +275
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Reapply the K-contiguous guard on the synthetic-pool BF16 path.

The real pool path already rejects tensors with stride(-1) != 1, but this new use_pool=False branch skips that check and passes state straight into the BF16 kernel as initial_state_source. A non-K-contiguous [B, HV, V, K] view can still satisfy the shape check and then produce wrong reads/writes in the fast path.

Suggested fix
         if use_pool:
             bf16_pool = initial_state
             bf16_indices = initial_state_indices
         else:
+            assert state.stride(-1) == 1, (
+                "state must be K-contiguous (stride[-1] == 1) for bf16 pretranspose decode, "
+                f"got stride={state.stride()}"
+            )
             bf16_pool = state
             bf16_indices = torch.arange(B, dtype=torch.int32, device=q.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 267–275, the synthetic-pool BF16
path bypasses the K-contiguous guard and can pass a non-K-contiguous `state`
into the BF16 fast path; update the `use_pool == False` branch to enforce the
same K-contiguity check used for the real pool: verify `state.stride(-1) == 1`
(or equivalently ensure last-dim contiguous) before assigning `bf16_pool =
state`, and if the check fails either raise an informative error or make a
contiguous copy (e.g., `state = state.contiguous()` for the last dimension) so
`initial_state_source`/`bf16_pool` is K-contiguous; keep the existing setting of
`bf16_indices` (the `torch.arange(B, ...)` line) unchanged.

if T == 1:
out = _gated_delta_rule_bf16_state(
A_log=A_log,
a=a,
Expand All @@ -276,12 +284,14 @@ def gated_delta_rule_decode_pretranspose(
k=k,
v=v,
b=b,
initial_state_source=state,
initial_state_source=bf16_pool,
initial_state_indices=bf16_indices,
output_state_indices=output_state_indices,
use_qk_l2norm_in_kernel=use_qk_l2norm,
scale=scale_val,
)
else:
# MTP kernel supports T>=1 and pool+indices
# MTP kernel for T>1 (supports pool+indices and intermediate caching)
out = _gated_delta_rule_bf16_state_mtp(
A_log=A_log,
a=a,
Expand All @@ -292,8 +302,8 @@ def gated_delta_rule_decode_pretranspose(
k=k,
v=v,
b=b,
initial_state_source=initial_state if use_pool else state,
initial_state_indices=initial_state_indices,
initial_state_source=bf16_pool,
initial_state_indices=bf16_indices,
output_state_indices=output_state_indices,
use_qk_l2norm_in_kernel=use_qk_l2norm,
scale=scale_val,
Expand Down
Loading
Loading