Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 57 additions & 28 deletions flashinfer/gdn_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,6 @@ def run_gdn_decode_kernel_small_batch_pretranspose(
tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout)

num_v_tiles = cute.ceil_div(v_dim, TILE_V)
v_dim * k_dim * batch_size * 4 / 1024 / 1024

vec_size = (
TILE_K // 32
Expand Down Expand Up @@ -840,18 +839,11 @@ def run_gdn_decode_kernel_big_batch_pretranspose(
tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout)

num_v_tiles = cute.ceil_div(v_dim, TILE_V)
v_dim * k_dim * batch_size * 4 / 1024 / 1024

vec_size = (
TILE_K // 32
) # Each thread in a warp processes this many elements (always 4 for TILE_K=128)

# print(f"Batched CP.ASYNC Load + Store (bypass L1 cache)")
# print(f" {batch_size} batches x {v_dim}x{k_dim} matrices")
# print(f" Tile: {TILE_V}x{TILE_K}, {num_v_tiles} tiles/batch")
# print(f" Threads: {NUM_THREADS} ({NUM_THREADS // 32} warps), vec_size: {vec_size}")
# print(f" Total: {total_data_mb:.1f} MB\n")

# Create SMEM layout
smem_layout_staged = cute.make_layout(
(TILE_V, TILE_K, NUM_STAGES), stride=(TILE_K, 1, TILE_V * TILE_K)
Expand Down Expand Up @@ -942,14 +934,16 @@ def gated_delta_rule_decode_pretranspose(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
state: torch.Tensor,
state: Optional[torch.Tensor],
A_log: torch.Tensor,
a: torch.Tensor,
dt_bias: torch.Tensor,
b: torch.Tensor,
scale: Optional[float] = None,
output: Optional[torch.Tensor] = None,
use_qk_l2norm: bool = True,
initial_state: Optional[torch.Tensor] = None,
initial_state_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Gated Delta Rule Decode kernel for single-token generation.

Expand All @@ -963,10 +957,11 @@ def gated_delta_rule_decode_pretranspose(
Current key of shape ``[B, 1, H, K]``. Must be float16/bfloat16.
v (torch.Tensor):
Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16.
state (torch.Tensor):
state (Optional[torch.Tensor]):
Current state of shape ``[B, HV, V, K]`` (v-major / K-last layout).
Float32: legacy kernel (T=1 only). Bfloat16: gdn_decode_klast_bf16_state backend
Float32: legacy kernel (T=1 only). Bfloat16: gdn_decode_klast_bf16_state backend
when T in 1..4 and K=V=128. Will be updated in-place.
Pass ``None`` when using ``initial_state`` / ``initial_state_indices`` instead.
A_log (torch.Tensor):
Log decay parameter of shape ``[HV]``. Must be float32.
a (torch.Tensor):
Expand All @@ -982,32 +977,61 @@ def gated_delta_rule_decode_pretranspose(
If None, will be allocated automatically.
use_qk_l2norm (bool):
Whether to apply L2 normalization to q and k. Default: ``True``.
initial_state (Optional[torch.Tensor]):
State pool of shape ``[pool_size, HV, V, K]`` (K-last / K-contiguous,
same layout as the per-batch ``state`` argument).
When provided, the kernel gathers directly from the pool using
``initial_state_indices`` and writes updates back in-place — eliminating
the caller-side gather/scatter overhead.
Requires bfloat16 state with T in 1..4 and K=V=128 (bf16 fast path).
initial_state_indices (Optional[torch.Tensor]):
Per-batch indices of shape ``[B]`` (int32 or int64) mapping each batch
entry to its slot in ``initial_state``. Required when ``initial_state``
is provided.

Returns:
Tuple[torch.Tensor, torch.Tensor]:
- output: Output tensor of shape ``[B, 1, HV, V]``
- state: Updated state tensor of shape ``[B, HV, V, K]``
- state or initial_state: Updated state (in-place).

Note:
- Requires SM90 (Hopper) architecture
- State is updated in-place
- Requires SM90+ (Hopper, Blackwell, etc.)
- State is always updated in-place; the pool path writes directly into
``initial_state`` memory (no separate scatter step needed)
- State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used.
and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
(supports both the direct ``state`` path and the pool+indices path).
- pool+indices (``initial_state``/``initial_state_indices``) only supported
via the bf16 fast path; float32 state raises an error.
- Legacy path (float32 state, T=1): K and V must be multiples of 4.
"""
# Validate input shapes
B, T, H, K = q.shape
_, _, HV, V = v.shape

# Validate state shape (Qwen-style K-last: [B, HV, V, K])
assert state.shape == (B, HV, V, K), (
f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}"
use_pool = initial_state is not None
assert use_pool == (initial_state_indices is not None), (
"initial_state and initial_state_indices must be provided together"
)

if use_pool:
pool_size = initial_state.shape[0]
assert initial_state.shape == (pool_size, HV, V, K), (
f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
f"got {initial_state.shape}"
)
Comment on lines +1017 to +1022
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

No bounds check on initial_state_indices values — out-of-bounds indices cause silent GPU memory corruption

initial_state.shape is validated, but individual index values in initial_state_indices are never checked against [0, pool_size). An out-of-range index produces a CUDA illegal-access fault or, worse, silently overwrites an adjacent allocation without any Python-level diagnostic.

🛡️ Proposed fix
 if use_pool:
     pool_size = initial_state.shape[0]
     assert initial_state.shape == (pool_size, HV, V, K), (
         f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
         f"got {initial_state.shape}"
     )
+    assert initial_state_indices.shape == (B,), (
+        f"Expected initial_state_indices shape [{B}], got {initial_state_indices.shape}"
+    )
+    assert (
+        int(initial_state_indices.min()) >= 0
+        and int(initial_state_indices.max()) < pool_size
+    ), (
+        f"initial_state_indices values must be in [0, {pool_size}), "
+        f"got [{int(initial_state_indices.min())}, {int(initial_state_indices.max())}]"
+    )

Note: .min()/.max() trigger a host sync; if that's unacceptable on a hot path, guard this behind assert statements that can be compiled-out, or document the constraint clearly in the docstring.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if use_pool:
pool_size = initial_state.shape[0]
assert initial_state.shape == (pool_size, HV, V, K), (
f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
f"got {initial_state.shape}"
)
if use_pool:
pool_size = initial_state.shape[0]
assert initial_state.shape == (pool_size, HV, V, K), (
f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], "
f"got {initial_state.shape}"
)
assert initial_state_indices.shape == (B,), (
f"Expected initial_state_indices shape [{B}], got {initial_state_indices.shape}"
)
assert (
int(initial_state_indices.min()) >= 0
and int(initial_state_indices.max()) < pool_size
), (
f"initial_state_indices values must be in [0, {pool_size}), "
f"got [{int(initial_state_indices.min())}, {int(initial_state_indices.max())}]"
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 1017 - 1022, The code validates
initial_state shape but never checks values in initial_state_indices; add a
bounds check (e.g., compute min and max of initial_state_indices and compare
against 0 and pool_size) and raise a clear error (IndexError or ValueError) if
any index is out of range before using initial_state_indices in use_pool path;
reference the variables initial_state_indices, initial_state, pool_size and the
use_pool branch and prefer using assert for hot-path avoidance if you need this
check to be compiled out (or document the requirement in the function
docstring).

else:
assert state is not None, "Either state or initial_state must be provided"
# Validate state shape (K-last: [B, HV, V, K])
assert state.shape == (B, HV, V, K), (
f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}"
)

# Backend: gdn_decode_klast_bf16_state when bf16 state, T<=4, K-last layout, K=V=128
state_dtype = initial_state.dtype if use_pool else state.dtype
use_gdn_decode_klast_bf16_state = (
_GDN_DECODE_KLAST_BF16_STATE_AVAILABLE
and state.dtype == torch.bfloat16
and state_dtype == torch.bfloat16
and T in (1, 2, 3, 4)
and K == 128
and V == 128
Expand All @@ -1028,7 +1052,8 @@ def gated_delta_rule_decode_pretranspose(
k=k,
v=v,
b=b,
initial_state_source=state,
initial_state_source=initial_state if use_pool else state,
initial_state_indices=initial_state_indices,
use_qk_l2norm_in_kernel=use_qk_l2norm,
scale=scale_val,
)
Expand All @@ -1040,9 +1065,14 @@ def gated_delta_rule_decode_pretranspose(
output = out
if output.dtype != target_dtype:
output = output.to(target_dtype)
return output, state
return_state = initial_state if use_pool else state
return output, return_state

# Legacy path: T=1 only, float32 state
# Legacy path: T=1 only, float32 state (no pool+indices support)
assert not use_pool, (
"pool+indices (initial_state/initial_state_indices) requires bfloat16 state "
"with T in 1..4 and K=V=128 (the gdn_decode_klast_bf16_state fast path)"
)
assert T == 1, f"Decode only supports T=1, got T={T}"
assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}"

Expand Down Expand Up @@ -1143,19 +1173,18 @@ def gated_delta_rule_decode_pretranspose(

# Run kernel directly with PyTorch tensors (no from_dlpack needed)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
cache["compiled"](
compiled(
h0_source, A_log, a, dt_bias, q, k, v, b, output, h0_indices, cu_seqlens, stream
)

# Copy state back only if state was not contiguous
# (if contiguous, reshape returns a view and kernel updated state in-place)
if not state.is_contiguous():
state.copy_(h0_source.reshape(B, HV, V, K))

# Convert output to target dtype if needed (kernel outputs bfloat16)
if output.dtype != target_dtype:
output = output.to(target_dtype)

# Copy state back only if state was not contiguous
# (if contiguous, reshape returns a view and kernel updated state in-place)
if not state.is_contiguous():
state.copy_(h0_source.reshape(B, HV, V, K))
return output, state


Expand Down
Loading
Loading