Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions flashinfer/gdn_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ def gated_delta_rule_decode_pretranspose(
- State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
(supports both the direct ``state`` path and the pool+indices path).
- pool+indices (``initial_state``/``initial_state_indices``) only supported
via the bf16 fast path; float32 state raises an error.
- pool+indices (``initial_state``/``initial_state_indices``) supported on
both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path
(T=1). The float32 path also supports negative indices for padding.
- Legacy path (float32 state, T=1): K and V must be multiples of 4.
"""
# Validate input shapes
Expand Down Expand Up @@ -239,13 +240,17 @@ def gated_delta_rule_decode_pretranspose(
return_state = initial_state if use_pool else state
return output, return_state

# Legacy path: T=1 only, float32 state (no pool+indices support)
assert not use_pool, (
"pool+indices (initial_state/initial_state_indices) requires bfloat16 state "
"with T in 1..4 and K=V=128 (the gdn_decode_klast_bf16_state fast path)"
)
# Legacy path: T=1 only, float32 state (supports pool+indices via CuTe DSL kernel)
use_pool_indexing = initial_state_indices is not None
assert T == 1, f"Decode only supports T=1, got T={T}"
assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}"

if use_pool:
assert initial_state.dtype == torch.float32, (
f"initial_state must be float32 for legacy path, got {initial_state.dtype}"
)
else:
assert state is not None, "Either state or initial_state must be provided"
assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}"

# Validate K and V constraints
assert K >= 128, f"K must be at least 128, got K={K}"
Expand Down Expand Up @@ -273,8 +278,18 @@ def gated_delta_rule_decode_pretranspose(
# Kernel outputs bfloat16, allocate in that dtype first
output = torch.zeros((B, T, HV, V), dtype=torch.bfloat16, device=q.device)

# Convert state from [B, HV, V, K] to [B*HV, V, K] for kernel
h0_source = state.reshape(B * HV, V, K)
# Build h0_source: [pool_size*HV, V, K] for kernel
if use_pool:
pool_size = initial_state.shape[0]
assert initial_state.is_contiguous(), (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we consider supporting non-contiguous states?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which situations do we need a non-contiguous state?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vLLM uses non-contiguous state

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vLLM uses non-contiguous state

For non-contiguous states, we should be able to compute the true indices using strides. Once the assert is removed, it can also work with our kernel.

"initial_state (pool) must be contiguous for correct kernel pointer arithmetic"
)
h0_source = initial_state.reshape(pool_size * HV, V, K)
return_state = initial_state
else:
pool_size = B
h0_source = state.reshape(pool_size * HV, V, K)
return_state = state

# Execute kernel
run_pretranspose_decode(
Expand All @@ -295,18 +310,21 @@ def gated_delta_rule_decode_pretranspose(
V,
scale,
use_qk_l2norm,
use_pool_indexing=use_pool_indexing,
initial_state_indices=initial_state_indices,
)

# Copy state back only if state was not contiguous
# Copy state back only if not using pool and state was not contiguous
# (if contiguous, reshape returns a view and kernel updated state in-place)
if not state.is_contiguous():
# Pool path: kernel writes directly into initial_state via pool indices
if not use_pool and not state.is_contiguous():
state.copy_(h0_source.reshape(B, HV, V, K))

# Convert output to target dtype if needed (kernel outputs bfloat16)
if output.dtype != target_dtype:
output = output.to(target_dtype)

return output, state
return output, return_state


# ============================================================================
Expand Down
Loading
Loading