diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 26c742e839..c0f2f6fbe4 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -737,7 +737,6 @@ def run_gdn_decode_kernel_small_batch_pretranspose( tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout) num_v_tiles = cute.ceil_div(v_dim, TILE_V) - v_dim * k_dim * batch_size * 4 / 1024 / 1024 vec_size = ( TILE_K // 32 @@ -840,18 +839,11 @@ def run_gdn_decode_kernel_big_batch_pretranspose( tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout) num_v_tiles = cute.ceil_div(v_dim, TILE_V) - v_dim * k_dim * batch_size * 4 / 1024 / 1024 vec_size = ( TILE_K // 32 ) # Each thread in a warp processes this many elements (always 4 for TILE_K=128) - # print(f"Batched CP.ASYNC Load + Store (bypass L1 cache)") - # print(f" {batch_size} batches x {v_dim}x{k_dim} matrices") - # print(f" Tile: {TILE_V}x{TILE_K}, {num_v_tiles} tiles/batch") - # print(f" Threads: {NUM_THREADS} ({NUM_THREADS // 32} warps), vec_size: {vec_size}") - # print(f" Total: {total_data_mb:.1f} MB\n") - # Create SMEM layout smem_layout_staged = cute.make_layout( (TILE_V, TILE_K, NUM_STAGES), stride=(TILE_K, 1, TILE_V * TILE_K) @@ -942,7 +934,7 @@ def gated_delta_rule_decode_pretranspose( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - state: torch.Tensor, + state: Optional[torch.Tensor], A_log: torch.Tensor, a: torch.Tensor, dt_bias: torch.Tensor, @@ -950,6 +942,8 @@ def gated_delta_rule_decode_pretranspose( scale: Optional[float] = None, output: Optional[torch.Tensor] = None, use_qk_l2norm: bool = True, + initial_state: Optional[torch.Tensor] = None, + initial_state_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Gated Delta Rule Decode kernel for single-token generation. @@ -963,10 +957,11 @@ def gated_delta_rule_decode_pretranspose( Current key of shape ``[B, 1, H, K]``. Must be float16/bfloat16. v (torch.Tensor): Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16. - state (torch.Tensor): + state (Optional[torch.Tensor]): Current state of shape ``[B, HV, V, K]`` (v-major / K-last layout). - Float32: legacy kernel (T=1 only). Bfloat16: gdn_decode_klast_bf16_state backend + Float32: legacy kernel (T=1 only). Bfloat16: gdn_decode_klast_bf16_state backend when T in 1..4 and K=V=128. Will be updated in-place. + Pass ``None`` when using ``initial_state`` / ``initial_state_indices`` instead. A_log (torch.Tensor): Log decay parameter of shape ``[HV]``. Must be float32. a (torch.Tensor): @@ -982,32 +977,61 @@ def gated_delta_rule_decode_pretranspose( If None, will be allocated automatically. use_qk_l2norm (bool): Whether to apply L2 normalization to q and k. Default: ``True``. + initial_state (Optional[torch.Tensor]): + State pool of shape ``[pool_size, HV, V, K]`` (K-last / K-contiguous, + same layout as the per-batch ``state`` argument). + When provided, the kernel gathers directly from the pool using + ``initial_state_indices`` and writes updates back in-place — eliminating + the caller-side gather/scatter overhead. + Requires bfloat16 state with T in 1..4 and K=V=128 (bf16 fast path). + initial_state_indices (Optional[torch.Tensor]): + Per-batch indices of shape ``[B]`` (int32 or int64) mapping each batch + entry to its slot in ``initial_state``. Required when ``initial_state`` + is provided. Returns: Tuple[torch.Tensor, torch.Tensor]: - output: Output tensor of shape ``[B, 1, HV, V]`` - - state: Updated state tensor of shape ``[B, HV, V, K]`` + - state or initial_state: Updated state (in-place). Note: - - Requires SM90 (Hopper) architecture - - State is updated in-place + - Requires SM90+ (Hopper, Blackwell, etc.) + - State is always updated in-place; the pool path writes directly into + ``initial_state`` memory (no separate scatter step needed) - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 - and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used. + and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used + (supports both the direct ``state`` path and the pool+indices path). + - pool+indices (``initial_state``/``initial_state_indices``) only supported + via the bf16 fast path; float32 state raises an error. - Legacy path (float32 state, T=1): K and V must be multiples of 4. """ # Validate input shapes B, T, H, K = q.shape _, _, HV, V = v.shape - # Validate state shape (Qwen-style K-last: [B, HV, V, K]) - assert state.shape == (B, HV, V, K), ( - f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}" + use_pool = initial_state is not None + assert use_pool == (initial_state_indices is not None), ( + "initial_state and initial_state_indices must be provided together" ) + if use_pool: + pool_size = initial_state.shape[0] + assert initial_state.shape == (pool_size, HV, V, K), ( + f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], " + f"got {initial_state.shape}" + ) + else: + assert state is not None, "Either state or initial_state must be provided" + # Validate state shape (K-last: [B, HV, V, K]) + assert state.shape == (B, HV, V, K), ( + f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}" + ) + # Backend: gdn_decode_klast_bf16_state when bf16 state, T<=4, K-last layout, K=V=128 + state_dtype = initial_state.dtype if use_pool else state.dtype use_gdn_decode_klast_bf16_state = ( _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE - and state.dtype == torch.bfloat16 + and state_dtype == torch.bfloat16 and T in (1, 2, 3, 4) and K == 128 and V == 128 @@ -1028,7 +1052,8 @@ def gated_delta_rule_decode_pretranspose( k=k, v=v, b=b, - initial_state_source=state, + initial_state_source=initial_state if use_pool else state, + initial_state_indices=initial_state_indices, use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale_val, ) @@ -1040,9 +1065,14 @@ def gated_delta_rule_decode_pretranspose( output = out if output.dtype != target_dtype: output = output.to(target_dtype) - return output, state + return_state = initial_state if use_pool else state + return output, return_state - # Legacy path: T=1 only, float32 state + # Legacy path: T=1 only, float32 state (no pool+indices support) + assert not use_pool, ( + "pool+indices (initial_state/initial_state_indices) requires bfloat16 state " + "with T in 1..4 and K=V=128 (the gdn_decode_klast_bf16_state fast path)" + ) assert T == 1, f"Decode only supports T=1, got T={T}" assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" @@ -1143,19 +1173,18 @@ def gated_delta_rule_decode_pretranspose( # Run kernel directly with PyTorch tensors (no from_dlpack needed) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - cache["compiled"]( + compiled( h0_source, A_log, a, dt_bias, q, k, v, b, output, h0_indices, cu_seqlens, stream ) - # Copy state back only if state was not contiguous - # (if contiguous, reshape returns a view and kernel updated state in-place) - if not state.is_contiguous(): - state.copy_(h0_source.reshape(B, HV, V, K)) - # Convert output to target dtype if needed (kernel outputs bfloat16) if output.dtype != target_dtype: output = output.to(target_dtype) + # Copy state back only if state was not contiguous + # (if contiguous, reshape returns a view and kernel updated state in-place) + if not state.is_contiguous(): + state.copy_(h0_source.reshape(B, HV, V, K)) return output, state diff --git a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py index 9bbbd849c6..1b3e0ecc67 100644 --- a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py +++ b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py @@ -707,6 +707,7 @@ def gated_delta_rule_decode_kernel_seqlen1( gA_log: cute.Tensor, gdt_bias: cute.Tensor, gH: cute.Tensor, + gH_slot_indices: cute.Tensor, gO: cute.Tensor, scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -730,6 +731,7 @@ def gated_delta_rule_decode_kernel_seqlen1( batch_idx = bidx // HV value_head_idx = bidx % HV query_head_idx = value_head_idx // (HV // H) + pool_batch_idx = gH_slot_indices[batch_idx] smem = utils.SmemAllocator() @@ -768,7 +770,7 @@ def gated_delta_rule_decode_kernel_seqlen1( cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) ) - h_global = gH[(batch_idx, value_head_idx, None, None)] + h_global = gH[(pool_batch_idx, value_head_idx, None, None)] # Launch first 2 async loads load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0) @@ -805,7 +807,7 @@ def gated_delta_rule_decode_kernel_seqlen1( for i in cutlass.range_constexpr(32): k_chunk[i] = k_sh[k_base + i] - h_out = gH[(batch_idx, value_head_idx, None, None)] + h_out = gH[(pool_batch_idx, value_head_idx, None, None)] o_head = gO[(batch_idx, 0, value_head_idx, None)] # ======================================================================== @@ -1114,7 +1116,8 @@ def gated_delta_rule_decode_kernel_seqlen234_unified( gb: cute.Tensor, # [B, T=2/3/4, HV] gA_log: cute.Tensor, # [HV] gdt_bias: cute.Tensor, # [HV] - gH: cute.Tensor, # [B, HV, V=128, K=128] - K-fast layout + gH: cute.Tensor, # [pool, HV, V=128, K=128] - K-fast layout + gH_slot_indices: cute.Tensor, # [B] indices mapping batch -> pool slot gO: cute.Tensor, # [B, T=2/3/4, HV, V=128] scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -1139,6 +1142,7 @@ def gated_delta_rule_decode_kernel_seqlen234_unified( batch_idx = bidx // HV value_head_idx = bidx % HV query_head_idx = value_head_idx // (HV // H) + pool_batch_idx = gH_slot_indices[batch_idx] warp_idx = tidx // 32 lane_idx = tidx % 32 @@ -1226,7 +1230,7 @@ def gated_delta_rule_decode_kernel_seqlen234_unified( ) # Upfront H loading - h_global = gH[(batch_idx, value_head_idx, None, None)] + h_global = gH[(pool_batch_idx, value_head_idx, None, None)] load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0) nvvm.cp_async_commit_group() load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32) @@ -1288,7 +1292,7 @@ def gated_delta_rule_decode_kernel_seqlen234_unified( load_v_to_smem(v_head3, v_sh3, tidx) # Output pointers - tokens 0, 1 always - h_out = gH[(batch_idx, value_head_idx, None, None)] + h_out = gH[(pool_batch_idx, value_head_idx, None, None)] o_head0 = gO[(batch_idx, 0, value_head_idx, None)] o_head1 = gO[(batch_idx, 1, value_head_idx, None)] @@ -1498,6 +1502,7 @@ def gated_delta_rule_launch_seqlen1( mA_log: cute.Tensor, mdt_bias: cute.Tensor, mH: cute.Tensor, + mH_slot_indices: cute.Tensor, mO: cute.Tensor, scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -1517,6 +1522,7 @@ def gated_delta_rule_launch_seqlen1( mA_log, mdt_bias, mH, + mH_slot_indices, mO, scale, softplus_beta, @@ -1544,6 +1550,7 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( gA_log: cute.Tensor, gdt_bias: cute.Tensor, gH: cute.Tensor, + gH_slot_indices: cute.Tensor, gO: cute.Tensor, scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -1568,6 +1575,7 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( query_head_idx = value_head_idx // (HV // H) v_row_base = v_chunk_idx * 32 + pool_batch_idx = gH_slot_indices[batch_idx] smem = utils.SmemAllocator() @@ -1593,7 +1601,7 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) ) - h_global = gH[(batch_idx, value_head_idx, None, None)] + h_global = gH[(pool_batch_idx, value_head_idx, None, None)] load_h_chunk_async(h_sh_chunk, h_global, tidx, v_row_base) nvvm.cp_async_commit_group() @@ -1623,7 +1631,7 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( for i in cutlass.range_constexpr(32): k_chunk[i] = k_sh[k_base + i] - h_out = gH[(batch_idx, value_head_idx, None, None)] + h_out = gH[(pool_batch_idx, value_head_idx, None, None)] o_head = gO[(batch_idx, 0, value_head_idx, None)] nvvm.cp_async_wait_group(0) @@ -1706,6 +1714,7 @@ def gated_delta_rule_launch_seqlen1_lowBS_1chunk( mA_log: cute.Tensor, mdt_bias: cute.Tensor, mH: cute.Tensor, + mH_slot_indices: cute.Tensor, mO: cute.Tensor, scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -1726,6 +1735,7 @@ def gated_delta_rule_launch_seqlen1_lowBS_1chunk( mA_log, mdt_bias, mH, + mH_slot_indices, mO, scale, softplus_beta, @@ -1748,6 +1758,7 @@ def gated_delta_rule_launch_seqlen2( mA_log: cute.Tensor, mdt_bias: cute.Tensor, mH: cute.Tensor, + mH_slot_indices: cute.Tensor, mO: cute.Tensor, scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -1767,6 +1778,7 @@ def gated_delta_rule_launch_seqlen2( mA_log, mdt_bias, mH, + mH_slot_indices, mO, scale, softplus_beta, @@ -1790,6 +1802,7 @@ def gated_delta_rule_launch_seqlen3( mA_log: cute.Tensor, mdt_bias: cute.Tensor, mH: cute.Tensor, + mH_slot_indices: cute.Tensor, mO: cute.Tensor, scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -1809,6 +1822,7 @@ def gated_delta_rule_launch_seqlen3( mA_log, mdt_bias, mH, + mH_slot_indices, mO, scale, softplus_beta, @@ -1832,6 +1846,7 @@ def gated_delta_rule_launch_seqlen4( mA_log: cute.Tensor, mdt_bias: cute.Tensor, mH: cute.Tensor, + mH_slot_indices: cute.Tensor, mO: cute.Tensor, scale: cutlass.Float32, softplus_beta: cutlass.Float32, @@ -1851,6 +1866,7 @@ def gated_delta_rule_launch_seqlen4( mA_log, mdt_bias, mH, + mH_slot_indices, mO, scale, softplus_beta, @@ -1942,8 +1958,10 @@ def gated_delta_rule( k: Key tensor [B, T, H, K] v: Value tensor [B, T, HV, V] b: Beta gate input [B, T, HV] - initial_state_source: H state [B, HV, V, K] (K-fast layout), modified in-place - initial_state_indices: Not used (for compatibility) + initial_state_source: H state [pool_size, HV, V, K] (K-fast layout), modified in-place. + For the direct path (no pool), pass [B, HV, V, K] and omit initial_state_indices. + initial_state_indices: Per-batch indices [B] (int32) mapping each batch entry to its + slot in initial_state_source. When None, uses identity mapping (arange(B)). use_qk_l2norm_in_kernel: Whether to L2-normalize Q/K in kernel (default: True) scale: Optional attention scale (default: 1/sqrt(K)) @@ -1984,10 +2002,19 @@ def gated_delta_rule( assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}" HV = v.shape[2] V = v.shape[3] + pool_size = initial_state_source.shape[0] if scale is None: scale = 1.0 / math.sqrt(K) + # Resolve indices: identity mapping when not provided + if initial_state_indices is None: + h_slot_indices = torch.arange(B, dtype=torch.int32, device=q.device) + elif initial_state_indices.dtype != torch.int32: + h_slot_indices = initial_state_indices.to(torch.int32) + else: + h_slot_indices = initial_state_indices + output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True) @@ -1998,6 +2025,7 @@ def gated_delta_rule( A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True) dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) h_ = from_dlpack(initial_state_source, assumed_align=32, enable_tvm_ffi=True) + h_slot_indices_ = from_dlpack(h_slot_indices, assumed_align=32, enable_tvm_ffi=True) o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) scale_f32 = cutlass.Float32(scale) @@ -2007,8 +2035,8 @@ def gated_delta_rule( stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Check cache - include all shape dimensions to avoid incorrect reuse - cache_key = (T, B, H, HV, K, V) + # Check cache - include pool_size so pool and direct paths don't collide + cache_key = (T, B, H, HV, K, V, pool_size) if cache_key not in _compiled_kernels: # Select and compile the appropriate kernel if T == 1 and B <= 4: @@ -2032,6 +2060,7 @@ def gated_delta_rule( A_log_, dt_bias_, h_, + h_slot_indices_, o_, scale_f32, softplus_beta_f32, @@ -2051,6 +2080,7 @@ def gated_delta_rule( A_log_, dt_bias_, h_, + h_slot_indices_, o_, scale_f32, softplus_beta_f32, diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 963198c8a6..dcd2b75dc9 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -395,6 +395,144 @@ def test_decode_kernel_basic_nontranspose( ) +# ============================================================================ +# Test pretranspose kernel with pool + indices path +# Verifies that passing initial_state=[pool,HV,V,K] + initial_state_indices=[B] +# produces identical output and in-place state updates as the gather-run-scatter +# direct-state path, and that unselected pool slots are untouched. +# ============================================================================ + + +def _test_decode_kernel_pretranspose_pool( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + scale: float, + pool_multiplier: int = 3, + state_dtype: str = "bfloat16", + seed: int | None = None, +): + """Pool+indices path must match gather → direct-state → scatter reference.""" + _skip_if_not_sm90_or_later() + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + pool_size = batch_size * pool_multiplier + dtype_torch = getattr(torch, dtype) + kv_dtype = getattr(torch, state_dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=dtype_torch) + k = torch.nn.functional.normalize( + torch.randn(batch_size, 1, num_k_heads, head_size, dtype=dtype_torch), + p=2.0, + dim=-1, + ) + v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) + + A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=dtype_torch) * 0.1 + a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 + b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) + + # Pool in [pool, HV, V, K] K-last layout (same as the direct-state layout) + pool = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=kv_dtype + ) + + # Non-trivial indices: every pool_multiplier-th slot (non-contiguous) + indices = ( + torch.arange(batch_size, dtype=torch.int32, device=device) * pool_multiplier + ) + + # ── Pool path (what we're testing) ────────────────────────────────────── + pool_under_test = pool.clone() + out_pool, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + use_qk_l2norm=True, + initial_state=pool_under_test, + initial_state_indices=indices, + ) + + # ── Direct-state reference (gather → kernel) ───────────────────────────── + gathered_state = pool[indices].clone() # [B, HV, V, K] + out_direct, updated_state = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=gathered_state, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + use_qk_l2norm=True, + ) + + atol = 5e-3 + rtol = 5e-3 + + # Outputs must match + torch.testing.assert_close(out_pool, out_direct, atol=atol, rtol=rtol) + + # Selected pool slots must match the state updated by the direct path + torch.testing.assert_close( + pool_under_test[indices], updated_state, atol=atol, rtol=rtol + ) + + # Non-selected pool slots must be exactly unchanged + mask = torch.ones(pool_size, dtype=torch.bool, device=device) + mask[indices] = False + torch.testing.assert_close(pool_under_test[mask], pool[mask], atol=0.0, rtol=0.0) + + print( + f"✓ Pool+indices pretranspose test passed " + f"(batch={batch_size}, pool={pool_size}, dtype={dtype})" + ) + + +@pytest.mark.parametrize("scale", [1.0]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_decode_kernel_pretranspose_pool( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + scale: float, + seed: int = int(os.environ.get("SEED", "0")), +): + _test_decode_kernel_pretranspose_pool( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + scale, + seed=seed, + ) + + # ============================================================================ # Test verify kernel with MTP version (Multiple Token Processing) # Reference: fp32 h state (default).