From 58cc26d2001e4a92f2e1b73b94df9a46386a8082 Mon Sep 17 00:00:00 2001 From: Svyatoslav Feldsherov Date: Sat, 28 Mar 2026 09:50:20 +0000 Subject: [PATCH] feat(gdn): separate input and output pool indices --- flashinfer/gdn_decode.py | 19 ++ .../gdn_kernels/gdn_decode_bf16_state.py | 61 +++++- .../gdn_kernels/gdn_decode_pretranspose.py | 46 ++++- tests/gdn/test_decode_delta_rule.py | 183 ++++++++++++++++++ 4 files changed, 292 insertions(+), 17 deletions(-) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 0d3410548c..ea679134a0 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -121,6 +121,7 @@ def gated_delta_rule_decode_pretranspose( use_qk_l2norm: bool = True, initial_state: Optional[torch.Tensor] = None, initial_state_indices: Optional[torch.Tensor] = None, + output_state_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Gated Delta Rule Decode kernel for single-token generation. @@ -165,6 +166,10 @@ def gated_delta_rule_decode_pretranspose( Per-batch indices of shape ``[B]`` (int32 or int64) mapping each batch entry to its slot in ``initial_state``. Required when ``initial_state`` is provided. + output_state_indices (Optional[torch.Tensor]): + Per-batch indices of shape ``[B]`` (int32 or int64) specifying where to write the updated state for each batch entry in the pool. + Requires ``initial_state`` to be provided. + If None, the kernel will write the updated state back to the same slot it read from (i.e., ``initial_state_indices``). 
Returns: Tuple[torch.Tensor, torch.Tensor]: @@ -191,6 +196,18 @@ def gated_delta_rule_decode_pretranspose( assert use_pool == (initial_state_indices is not None), ( "initial_state and initial_state_indices must be provided together" ) + if output_state_indices is not None: + assert use_pool, ( + "output_state_indices can only be used with initial_state (pool mode)" + ) + assert output_state_indices.shape == (B,), ( + f"Expected output_state_indices shape [{B}], " + f"got {output_state_indices.shape}" + ) + assert output_state_indices.dtype in (torch.int32, torch.int64), ( + f"output_state_indices must be int32 or int64, " + f"got {output_state_indices.dtype}" + ) if use_pool: pool_size = initial_state.shape[0] @@ -253,6 +270,7 @@ def gated_delta_rule_decode_pretranspose( b=b, initial_state_source=initial_state if use_pool else state, initial_state_indices=initial_state_indices, + output_state_indices=output_state_indices, use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale_val, ) @@ -339,6 +357,7 @@ def gated_delta_rule_decode_pretranspose( use_qk_l2norm, use_pool_indexing=use_pool_indexing, initial_state_indices=initial_state_indices, + output_state_indices=output_state_indices, ) # Copy state back only if not using pool and state was not contiguous diff --git a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py index fa4a4a4f4f..4a8f00c968 100644 --- a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py +++ b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py @@ -1107,7 +1107,8 @@ def gdn_decode_bf16state_mtp_kernel( v: cute.Tensor, # [B, T, HV, V] b: cute.Tensor, # [B, T, HV] o: cute.Tensor, # [B, T, HV, V] - output - h0_indices: cute.Tensor, # [B] - initial state indices + h0_indices: cute.Tensor, # [B] - initial state indices (read) + h0_out_indices: cute.Tensor, # [B] - output state indices (write) softplus_beta: cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], scale: cutlass.Constexpr[float], @@ 
-1320,6 +1321,8 @@ def gdn_decode_bf16state_mtp_kernel( # Each group handles tile_v/num_groups V rows, 8 at a time (ILP=8) flat_state_idx = cache_idx * HV + i_hv + write_cache_idx = h0_out_indices[i_n] + flat_write_idx = write_cache_idx * HV + i_hv rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups eighth_rows: cutlass.Constexpr[int] = rows_per_group // MTP_ILP_ROWS @@ -1975,14 +1978,38 @@ def gdn_decode_bf16state_mtp_kernel( r_hb5[i] = cutlass.BFloat16(r_h[5, i]) r_hb6[i] = cutlass.BFloat16(r_h[6, i]) r_hb7[i] = cutlass.BFloat16(r_h[7, i]) - cute.autovec_copy(r_hb0, ht0) - cute.autovec_copy(r_hb1, ht1) - cute.autovec_copy(r_hb2, ht2) - cute.autovec_copy(r_hb3, ht3) - cute.autovec_copy(r_hb4, ht4) - cute.autovec_copy(r_hb5, ht5) - cute.autovec_copy(r_hb6, ht6) - cute.autovec_copy(r_hb7, ht7) + wt0 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v0, lane_in_group) + ) + wt1 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v1, lane_in_group) + ) + wt2 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v2, lane_in_group) + ) + wt3 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v3, lane_in_group) + ) + wt4 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v4, lane_in_group) + ) + wt5 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v5, lane_in_group) + ) + wt6 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v6, lane_in_group) + ) + wt7 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v7, lane_in_group) + ) + cute.autovec_copy(r_hb0, wt0) + cute.autovec_copy(r_hb1, wt1) + cute.autovec_copy(r_hb2, wt2) + cute.autovec_copy(r_hb3, wt3) + cute.autovec_copy(r_hb4, wt4) + cute.autovec_copy(r_hb5, wt5) + cute.autovec_copy(r_hb6, wt6) + cute.autovec_copy(r_hb7, wt7) # ============================================================================== @@ -2003,6 +2030,7 @@ def run_gdn_decode_bf16state_mtp( b: cute.Tensor, o: 
cute.Tensor, h0_indices: cute.Tensor, + h0_out_indices: cute.Tensor, softplus_beta: cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], scale: cutlass.Constexpr[float], @@ -2056,6 +2084,7 @@ def run_gdn_decode_bf16state_mtp( b, o, h0_indices, + h0_out_indices, softplus_beta, softplus_threshold, scale, @@ -2467,6 +2496,7 @@ def gated_delta_rule_mtp( b: Optional[torch.Tensor] = None, initial_state_source: Optional[torch.Tensor] = None, initial_state_indices: Optional[torch.Tensor] = None, + output_state_indices: Optional[torch.Tensor] = None, intermediate_states_buffer: Optional[torch.Tensor] = None, disable_state_update: bool = False, use_qk_l2norm_in_kernel: bool = True, @@ -2487,7 +2517,9 @@ def gated_delta_rule_mtp( v: [B, T, HV, V] bf16 b: [B, T, HV] bf16 initial_state_source: [pool_size, HV, V, K] bf16 - initial_state_indices: [B] int32 - indices into state pool + initial_state_indices: [B] int32 - indices into state pool (read) + output_state_indices: Optional [B] int32 - indices for writing updated state. + Defaults to initial_state_indices when None. 
intermediate_states_buffer: Optional [pool_size, T, HV, V, K] bf16 disable_state_update: bool - if True, don't update initial state scale: Optional, default 1/sqrt(K) @@ -2514,6 +2546,12 @@ def gated_delta_rule_mtp( if initial_state_indices is None: initial_state_indices = torch.arange(B, dtype=torch.int32, device=q.device) + # Default output indices to read indices + if output_state_indices is None: + output_state_indices = initial_state_indices + elif output_state_indices.dtype != torch.int32: + output_state_indices = output_state_indices.to(torch.int32) + if output is None: output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) @@ -2552,6 +2590,7 @@ def gated_delta_rule_mtp( dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True) + h0_out_idx_ = from_dlpack(output_state_indices, assumed_align=32, enable_tvm_ffi=True) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -2590,6 +2629,7 @@ def gated_delta_rule_mtp( b_, o_, h0_idx_, + h0_out_idx_, softplus_beta, softplus_threshold, scale, @@ -2620,6 +2660,7 @@ def gated_delta_rule_mtp( b_, o_, h0_idx_, + h0_out_idx_, stream, ) diff --git a/flashinfer/gdn_kernels/gdn_decode_pretranspose.py b/flashinfer/gdn_kernels/gdn_decode_pretranspose.py index 7675d29111..ff2febe951 100644 --- a/flashinfer/gdn_kernels/gdn_decode_pretranspose.py +++ b/flashinfer/gdn_kernels/gdn_decode_pretranspose.py @@ -55,7 +55,8 @@ def gdn_decode_kernel_small_batch_pretranspose( v: cute.Tensor, # [B, T, HV, V] b: cute.Tensor, # [B, T, HV] o: cute.Tensor, # [B, T, HV, V] - output - h0_indices: cute.Tensor, # [B] - initial state indices + h0_indices: cute.Tensor, # [B] - initial state indices (read) + h0_out_indices: cute.Tensor, # [B] - output state indices (write) cu_seqlens: cute.Tensor, # [B+1] - cumulative sequence lengths (for varlen) softplus_beta: 
cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], @@ -134,16 +135,18 @@ def gdn_decode_kernel_small_batch_pretranspose( # Compute state index: use pool indexing if enabled. if cutlass.const_expr(use_pool_indexing): pool_idx = h0_indices[i_n] + out_pool_idx = h0_out_indices[i_n] else: pool_idx = 0 + out_pool_idx = 0 if pool_idx >= 0: - # Get current state slice. + # Get the current batch's state slice (read from pool_idx, write to out_pool_idx). if cutlass.const_expr(use_pool_indexing): # h0_source layout: [pool_size, HV, V, K] (supports non-contiguous page stride) gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] # (V, K) gDst = cute.local_tile( - h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0) + h0_source, (1, 1, TILE_V, TILE_K), (out_pool_idx, i_hv, None, 0) ) else: # h0_source layout: [B*HV, V, K] @@ -307,7 +310,7 @@ def gdn_decode_kernel_small_batch_pretranspose( r_h[i] += r_k[i] * v_new sum_hq += r_h[i] * r_q[i] - # Write h back to state. + # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) if cutlass.const_expr(use_pool_indexing): gDst_tile = cute.local_tile( gDst, @@ -361,7 +364,8 @@ def gdn_decode_kernel_big_batch_pretranspose( v: cute.Tensor, # [B, T, HV, V] b: cute.Tensor, # [B, T, HV] o: cute.Tensor, # [B, T, HV, V] - output - h0_indices: cute.Tensor, # [B] - initial state indices + h0_indices: cute.Tensor, # [B] - initial state indices (read) + h0_out_indices: cute.Tensor, # [B] - output state indices (write) cu_seqlens: cute.Tensor, # [B+1] - cumulative sequence lengths (for varlen) softplus_beta: cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], @@ -436,8 +440,10 @@ def gdn_decode_kernel_big_batch_pretranspose( # Compute state index: use pool indexing if enabled. if cutlass.const_expr(use_pool_indexing): pool_idx = h0_indices[i_n] + out_pool_idx = h0_out_indices[i_n] else: pool_idx = 0 + out_pool_idx = 0 if pool_idx >= 0: # Get current state slice.
@@ -445,7 +451,7 @@ def gdn_decode_kernel_big_batch_pretranspose( # h0_source layout: [pool_size, HV, V, K] (supports non-contiguous page stride) gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] # (V, K) gDst = cute.local_tile( - h0_source, (1, 1, TILE_V, TILE_K), (pool_idx, i_hv, None, 0) + h0_source, (1, 1, TILE_V, TILE_K), (out_pool_idx, i_hv, None, 0) ) else: # h0_source layout: [B*HV, V, K] @@ -657,6 +663,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose( b: cute.Tensor, o: cute.Tensor, h0_indices: cute.Tensor, + h0_out_indices: cute.Tensor, cu_seqlens: cute.Tensor, softplus_beta: cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], @@ -734,6 +741,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose( b, o, h0_indices, + h0_out_indices, cu_seqlens, softplus_beta, softplus_threshold, @@ -768,6 +776,7 @@ def run_gdn_decode_kernel_big_batch_pretranspose( b: cute.Tensor, o: cute.Tensor, h0_indices: cute.Tensor, + h0_out_indices: cute.Tensor, cu_seqlens: cute.Tensor, softplus_beta: cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], @@ -840,6 +849,7 @@ def run_gdn_decode_kernel_big_batch_pretranspose( b, o, h0_indices, + h0_out_indices, cu_seqlens, softplus_beta, softplus_threshold, @@ -910,6 +920,7 @@ def run_pretranspose_decode( use_qk_l2norm: bool, use_pool_indexing: bool = False, initial_state_indices: Optional[torch.Tensor] = None, + output_state_indices: Optional[torch.Tensor] = None, ): """Compile and execute the pretranspose decode kernel. @@ -924,6 +935,8 @@ def run_pretranspose_decode( use_pool_indexing: Whether to use pool-based indirect state indexing. initial_state_indices: Int32 indices into state pool, shape [B]. Negative values indicate padding (kernel writes zeros). + output_state_indices: Optional int32 indices for write destination, shape [B]. + When None, writes go to the same slot as initial_state_indices. 
""" # Compile kernel with TVM FFI (cached) if use_pool_indexing: @@ -959,6 +972,11 @@ def run_pretranspose_decode( h0_indices = initial_state_indices.to(torch.int32) else: h0_indices = cache["h0_indices"] + # Resolve output indices: default to same as read indices + if use_pool_indexing and output_state_indices is not None: + h0_out_indices = output_state_indices.to(torch.int32) + else: + h0_out_indices = h0_indices cu_seqlens = cache["cu_seqlens"] if "compiled" not in cache: @@ -976,6 +994,7 @@ def run_pretranspose_decode( b_tensor = from_dlpack(b, assumed_align=16) o_tensor = from_dlpack(output, assumed_align=16) h0_indices_tensor = from_dlpack(h0_indices, assumed_align=16) + h0_out_indices_tensor = from_dlpack(h0_out_indices, assumed_align=16) cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16) # Always use 8-CTA architecture (benchmarks show it's better for all batch sizes) @@ -994,6 +1013,7 @@ def run_pretranspose_decode( b_tensor, o_tensor, h0_indices_tensor, + h0_out_indices_tensor, cu_seqlens_tensor, softplus_beta=1.0, softplus_threshold=20.0, @@ -1018,5 +1038,17 @@ def run_pretranspose_decode( # Run kernel directly with PyTorch tensors (no from_dlpack needed) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) cache["compiled"]( - h0_source, A_log, a, dt_bias, q, k, v, b, output, h0_indices, cu_seqlens, stream + h0_source, + A_log, + a, + dt_bias, + q, + k, + v, + b, + output, + h0_indices, + h0_out_indices, + cu_seqlens, + stream, ) diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 1b43a0ddfe..35120483dd 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -2082,3 +2082,186 @@ def test_gdn_decode_bf16_state_mtp_kernel( " gdn_decode_bf16_state: pytest test_decode_delta_rule.py::test_gdn_decode_bf16_state_kernel -v" ) print(" ALL: pytest test_decode_delta_rule.py -v") + + +# ============================================================================ +# 
Tests for output_state_indices (separate read/write pool indices) +# ============================================================================ + + +@pytest.mark.parametrize("state_dtype", ["bfloat16", "float32"]) +@pytest.mark.parametrize("batch_size", [1, 4, 16]) +def test_output_state_indices(batch_size: int, state_dtype: str): + """Test that output_state_indices writes to different pool slots than read.""" + _skip_if_not_sm90_or_later() + + num_q_heads: int = 16 + num_k_heads: int = 16 + num_v_heads: int = 32 + head_size: int = 128 + + seed: int = 42 + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + pool_size = batch_size * 4 # plenty of room + dtype_torch = torch.bfloat16 + kv_dtype = getattr(torch, state_dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=dtype_torch) + k = torch.nn.functional.normalize( + torch.randn(batch_size, 1, num_k_heads, head_size, dtype=dtype_torch), + p=2.0, + dim=-1, + ) + v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) + A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 + b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) + + pool = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=kv_dtype + ) + + # Read from first batch_size slots, write to second batch_size slots + read_indices = torch.arange(batch_size, dtype=torch.int32, device=device) + write_indices = torch.arange( + batch_size, 2 * batch_size, dtype=torch.int32, device=device + ) + + pool_orig = pool.clone() + pool_under_test = pool.clone() + + out, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=1.0, + use_qk_l2norm=True, + 
initial_state=pool_under_test, + initial_state_indices=read_indices, + output_state_indices=write_indices, + ) + + # Reference: direct state path (gather from read slots) + gathered = pool_orig[read_indices].clone() + out_ref, updated_ref = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=gathered, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=1.0, + use_qk_l2norm=True, + ) + + atol = 1e-3 + rtol = 1e-3 + + # Outputs must match + torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) + + # Write slots must contain updated state + torch.testing.assert_close( + pool_under_test[write_indices], updated_ref, atol=atol, rtol=rtol + ) + + # Read slots must be unchanged (we wrote to different slots) + torch.testing.assert_close( + pool_under_test[read_indices], pool_orig[read_indices], atol=atol, rtol=rtol + ) + + # Other slots must be unchanged + used_mask = torch.zeros(pool_size, dtype=torch.bool, device=device) + used_mask[read_indices] = True + used_mask[write_indices] = True + torch.testing.assert_close( + pool_under_test[~used_mask], pool_orig[~used_mask], atol=atol, rtol=rtol + ) + + +@pytest.mark.parametrize("state_dtype", ["bfloat16", "float32"]) +@pytest.mark.parametrize("batch_size", [1, 4, 16]) +def test_output_state_indices_same_as_input(batch_size: int, state_dtype: str): + """output_state_indices == initial_state_indices must match existing pool behavior.""" + _skip_if_not_sm90_or_later() + + torch.random.manual_seed(42) + torch.cuda.manual_seed(42) + + num_sab_heads = 32 + pool_size = batch_size * 3 + dtype_torch = torch.bfloat16 + kv_dtype = getattr(torch, state_dtype) + device = torch.device("cuda") + head_size = 128 + + with device: + q = torch.randn(batch_size, 1, 16, head_size, dtype=dtype_torch) + k = torch.nn.functional.normalize( + torch.randn(batch_size, 1, 16, head_size, dtype=dtype_torch), + p=2.0, + dim=-1, + ) + v = torch.randn(batch_size, 1, num_sab_heads, head_size, dtype=dtype_torch) + A_log = 
torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 + b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) + pool = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=kv_dtype + ) + indices = torch.arange(batch_size, dtype=torch.int32, device=device) * 3 + + # Without output_state_indices + pool1 = pool.clone() + out1, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=1.0, + use_qk_l2norm=True, + initial_state=pool1, + initial_state_indices=indices, + ) + + # With output_state_indices == initial_state_indices + pool2 = pool.clone() + out2, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=1.0, + use_qk_l2norm=True, + initial_state=pool2, + initial_state_indices=indices, + output_state_indices=indices, + ) + atol = 1e-3 + rtol = 1e-3 + + torch.testing.assert_close(out1, out2, atol=atol, rtol=rtol) + torch.testing.assert_close(pool1, pool2, atol=atol, rtol=rtol)