diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index c1a90c9fbe..af810232ea 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -264,8 +264,11 @@ def gated_delta_rule_decode_pretranspose( ) assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" scale_val = K**-0.5 if scale is None else scale - if T == 1 and not use_pool: - # T=1 kernel does not accept initial_state_indices + if T == 1: + # T=1 path handles both no-pool and pool+indices (with -1 padding). + # Dispatches internally to the ILP kernel for B>=16 or the MTP-T=1 + # kernel for B<16; both accept initial_state_indices since the + # pool+padding refactor. out = _gated_delta_rule_bf16_state( A_log=A_log, a=a, @@ -276,12 +279,14 @@ def gated_delta_rule_decode_pretranspose( k=k, v=v, b=b, - initial_state_source=state, + initial_state_source=initial_state if use_pool else state, + initial_state_indices=initial_state_indices, + output_state_indices=output_state_indices, use_qk_l2norm_in_kernel=use_qk_l2norm, scale=scale_val, ) else: - # MTP kernel supports T>=1 and pool+indices + # MTP kernel for T>1 (supports pool+indices and intermediate caching) out = _gated_delta_rule_bf16_state_mtp( A_log=A_log, a=a, diff --git a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py index b559c37731..a0ea0627f0 100644 --- a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py +++ b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py @@ -385,7 +385,7 @@ def gdn_decode_bf16state_cooprow_kernel( @cute.kernel def gdn_decode_bf16state_ilp_kernel( - h0_source: cute.Tensor, # [B*HV, V, K] as BF16 (K-last, autovec_copy compatible) + h0_source: cute.Tensor, # [pool_size*HV, V, K] as BF16 (K-last, autovec_copy compatible) vec_size: cutlass.Constexpr[int], num_v_tiles: cutlass.Constexpr[int], tile_v: cutlass.Constexpr[int], @@ -397,6 +397,8 @@ def gdn_decode_bf16state_ilp_kernel( v: cute.Tensor, # [B, 1, HV, V] b: cute.Tensor, # [B, 1, HV] o: cute.Tensor, # [B, 1, HV, V] - output + h0_indices: cute.Tensor, # [B] int32 - initial state pool slot to read + h0_out_indices: cute.Tensor, # [B] int32 - pool slot to write updated state softplus_beta: cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], scale: cutlass.Constexpr[float], @@ -410,6 +412,8 @@ def gdn_decode_bf16state_ilp_kernel( """ ILP-optimized T=1 GDN decode kernel with BF16 state. Direct GMEM->register loads with 8-row ILP for high memory throughput. + Pool+padding: reads state from h0_indices[i_n]; writes to h0_out_indices[i_n]. + Negative indices redirect to slot 0 (caller must reserve slot 0 as null buffer). """ tidx, _, _ = cute.arch.thread_idx() lane_id = tidx % 32 @@ -547,7 +551,15 @@ def gdn_decode_bf16state_ilp_kernel( # =================================================================== # Main loop: process V rows with 8-row ILP # =================================================================== - flat_state_idx = i_n * HV + i_hv + # Pool+padding: read slot, write slot, redirect negatives to slot 0. 
+ cache_idx = h0_indices[i_n] + if cache_idx < 0: + cache_idx = cutlass.Int32(0) + flat_state_idx = cache_idx * HV + i_hv + write_cache_idx = h0_out_indices[i_n] + if write_cache_idx < 0: + write_cache_idx = cutlass.Int32(0) + flat_write_idx = write_cache_idx * HV + i_hv rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups eighth_rows: cutlass.Constexpr[int] = rows_per_group // ILP_ROWS @@ -1056,6 +1068,7 @@ def gdn_decode_bf16state_ilp_kernel( cute.autovec_copy(r_o_bf16_vec, ot_slice) # Write updated H back to GMEM: FP32 regs -> BF16 regs -> GMEM BF16 (vectorized) + # Writes go to flat_write_idx (separate from read; supports in/out pool split). for i in cutlass.range_constexpr(vec_size): r_hb0[i] = cutlass.BFloat16(r_h[0, i]) r_hb1[i] = cutlass.BFloat16(r_h[1, i]) @@ -1065,14 +1078,38 @@ def gdn_decode_bf16state_ilp_kernel( r_hb5[i] = cutlass.BFloat16(r_h[5, i]) r_hb6[i] = cutlass.BFloat16(r_h[6, i]) r_hb7[i] = cutlass.BFloat16(r_h[7, i]) - cute.autovec_copy(r_hb0, ht0) - cute.autovec_copy(r_hb1, ht1) - cute.autovec_copy(r_hb2, ht2) - cute.autovec_copy(r_hb3, ht3) - cute.autovec_copy(r_hb4, ht4) - cute.autovec_copy(r_hb5, ht5) - cute.autovec_copy(r_hb6, ht6) - cute.autovec_copy(r_hb7, ht7) + wt0 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v0, lane_in_group) + ) + wt1 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v1, lane_in_group) + ) + wt2 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v2, lane_in_group) + ) + wt3 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v3, lane_in_group) + ) + wt4 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v4, lane_in_group) + ) + wt5 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v5, lane_in_group) + ) + wt6 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v6, lane_in_group) + ) + wt7 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_write_idx, v7, lane_in_group) + ) + cute.autovec_copy(r_hb0, wt0) + cute.autovec_copy(r_hb1, wt1) + cute.autovec_copy(r_hb2, wt2) + cute.autovec_copy(r_hb3, wt3) + cute.autovec_copy(r_hb4, wt4) + cute.autovec_copy(r_hb5, wt5) + cute.autovec_copy(r_hb6, wt6) + cute.autovec_copy(r_hb7, wt7) # ============================================================================== @@ -2117,7 +2154,7 @@ def run_gdn_decode_bf16state_mtp( @cute.jit def run_gdn_decode_bf16state_ilp( - h0_source: cute.Tensor, # [B*HV, V, K] BF16 + h0_source: cute.Tensor, # [pool_size*HV, V, K] BF16 A_log: cute.Tensor, a: cute.Tensor, dt_bias: cute.Tensor, @@ -2126,6 +2163,8 @@ def run_gdn_decode_bf16state_ilp( v: cute.Tensor, b: cute.Tensor, o: cute.Tensor, + h0_indices: cute.Tensor, + h0_out_indices: cute.Tensor, softplus_beta: cutlass.Constexpr[float], softplus_threshold: cutlass.Constexpr[float], scale: cutlass.Constexpr[float], @@ -2167,6 +2206,8 @@ def run_gdn_decode_bf16state_ilp( v, b, o, + h0_indices, + h0_out_indices, softplus_beta, softplus_threshold, scale, @@ -2294,6 +2335,11 @@ def run_gdn_decode_bf16state_cooprow( # Number of SMs on target GPU (detected dynamically) NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count +# GPU architecture detected once at import time — avoids per-call +# torch.cuda.get_device_capability() in the hot path. +_GPU_MAJOR, _ = torch.cuda.get_device_capability(0) +_USE_PACKED_FMA = _GPU_MAJOR >= 10 + def _select_tile_v_for_batch(B: int, HV: int, V: int) -> int: """Select optimal tile_v for the ILP kernel based on batch size. 
@@ -2326,6 +2372,9 @@ def gated_delta_rule( v: Optional[torch.Tensor] = None, b: Optional[torch.Tensor] = None, initial_state_source: Optional[torch.Tensor] = None, + initial_state_indices: Optional[torch.Tensor] = None, + output_state_indices: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, use_qk_l2norm_in_kernel: bool = True, scale: Optional[float] = None, ) -> torch.Tensor: @@ -2340,7 +2389,13 @@ def gated_delta_rule( k: [B, 1, H, K] bf16 v: [B, 1, HV, V] bf16 b: [B, 1, HV] bf16 - initial_state_source: [B, HV, V, K] bf16 (modified in-place) + initial_state_source: [B, HV, V, K] when no indices, else + [pool_size, HV, V, K] bf16 (modified in-place) + initial_state_indices: Optional [B] int32 — pool slots to read. + Negative entries redirect to slot 0 (null buffer). + output_state_indices: Optional [B] int32 — pool slots to write. + Defaults to initial_state_indices when None. + output: Optional pre-allocated [B, 1, HV, V] bf16 output scale: Optional, default 1/sqrt(K) Returns: @@ -2361,7 +2416,11 @@ def gated_delta_rule( if scale is None: scale = 1.0 / math.sqrt(K) - # Small batch: route through MTP kernel (T=1 path) with identity indices. + use_pool = initial_state_indices is not None + if output_state_indices is not None and output_state_indices.dtype != torch.int32: + output_state_indices = output_state_indices.to(torch.int32) + + # Small batch: route through MTP kernel (T=1 path). # The cooprow kernel has known correctness issues at small batch sizes (e.g. B=2). # The MTP kernel's T=1 path uses the same ILP-style computation and is well-tested. if B < ILP_BATCH_THRESHOLD: @@ -2376,29 +2435,20 @@ def gated_delta_rule( v=v, b=b, initial_state_source=initial_state_source, + initial_state_indices=initial_state_indices, + output_state_indices=output_state_indices, + output=output, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, scale=scale, ) - output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) - - # Reshape state to [B*HV, V, K] - h0_source = initial_state_source.reshape(B * HV, V, K) - - q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True) - k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True) - v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True) - a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True) - b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True) - A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True) - dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) - h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True) - o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) + # Reshape state: no-pool [B, HV, V, K] -> [B*HV, V, K]; + # pool [pool_size, HV, V, K] -> [pool_size*HV, V, K]. + pool_size = initial_state_source.shape[0] if use_pool else B + h0_source = initial_state_source.reshape(pool_size * HV, V, K) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - major, _ = torch.cuda.get_device_capability(q.device) - use_packed_fma = major >= 10 + use_packed_fma = _USE_PACKED_FMA # B >= ILP_BATCH_THRESHOLD (small B handled by MTP path above) tile_v = _select_tile_v_for_batch(B, HV, V) @@ -2409,6 +2459,7 @@ def gated_delta_rule( HV, K, V, + pool_size, tile_v, scale, softplus_beta, @@ -2416,48 +2467,86 @@ def gated_delta_rule( use_packed_fma, ) if cache_key not in _compiled_kernels_ilp: + # First call for this shape: allocate defaults and do dlpack + # conversions once for compilation. 
+        # Steady-state calls reuse the cached defaults and pass torch tensors
+        # straight to the compiled callable (tvm-ffi accepts either).
+        default_output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
+        default_indices = torch.arange(B, dtype=torch.int32, device=q.device)
+
+        q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True)
+        k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True)
+        v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True)
+        a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True)
+        b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True)
+        A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True)
+        dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
+        h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True)
+        o_ = from_dlpack(default_output, assumed_align=32, enable_tvm_ffi=True)
+        h0_idx_ = from_dlpack(default_indices, assumed_align=32, enable_tvm_ffi=True)
+        h0_out_idx_ = from_dlpack(
+            default_indices, assumed_align=32, enable_tvm_ffi=True
+        )
+
         # Use maxrregcount=64 for smaller tile_v to improve occupancy
         # when grid size is small (fewer waves)
         if tile_v < 128:
             compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3 --ptxas-options=-maxrregcount=64"
         else:
             compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3"
-        _compiled_kernels_ilp[cache_key] = cute.compile(
-            run_gdn_decode_bf16state_ilp,
-            h_,
-            A_log_,
-            a_,
-            dt_bias_,
-            q_,
-            k_,
-            v_,
-            b_,
-            o_,
-            softplus_beta,
-            softplus_threshold,
-            scale,
-            HV,
-            B,
-            H,
-            K,
-            V,
-            use_qk_l2norm_in_kernel,
-            use_packed_fma,
-            tile_v,
-            stream,
-            options=compile_opts,
-        )
+        _compiled_kernels_ilp[cache_key] = {
+            "compiled": cute.compile(
+                run_gdn_decode_bf16state_ilp,
+                h_,
+                A_log_,
+                a_,
+                dt_bias_,
+                q_,
+                k_,
+                v_,
+                b_,
+                o_,
+                h0_idx_,
+                h0_out_idx_,
+                softplus_beta,
+                softplus_threshold,
+                scale,
+                HV,
+                B,
+                H,
+                K,
+                V,
+                use_qk_l2norm_in_kernel,
+                use_packed_fma,
+                tile_v,
+                stream,
+                options=compile_opts,
+            ),
+            "output": default_output,
+            "default_indices": default_indices,
+        }
+
+    cache = _compiled_kernels_ilp[cache_key]
+
+    if initial_state_indices is None:
+        initial_state_indices = cache["default_indices"]
+    if output_state_indices is None:
+        output_state_indices = initial_state_indices
+    if output is None:
+        output = cache["output"]

-    _compiled_kernels_ilp[cache_key](
-        h_,
-        A_log_,
-        a_,
-        dt_bias_,
-        q_,
-        k_,
-        v_,
-        b_,
-        o_,
+    cache["compiled"](
+        h0_source,
+        A_log,
+        a,
+        dt_bias,
+        q,
+        k,
+        v,
+        b,
+        output,
+        initial_state_indices,
+        output_state_indices,
         stream,
     )

@@ -2547,18 +2636,9 @@ def gated_delta_rule_mtp(
     if scale is None:
         scale = 1.0 / math.sqrt(K)

-    if initial_state_indices is None:
-        initial_state_indices = torch.arange(B, dtype=torch.int32, device=q.device)
-
-    # Default output indices to read indices
-    if output_state_indices is None:
-        output_state_indices = initial_state_indices
-    elif output_state_indices.dtype != torch.int32:
+    if output_state_indices is not None and output_state_indices.dtype != torch.int32:
         output_state_indices = output_state_indices.to(torch.int32)

-    if output is None:
-        output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype)
-
     # Reshape state to [pool_size * HV, V, K]
     h0_source = initial_state_source.reshape(pool_size * HV, V, K)

@@ -2583,25 +2663,8 @@
     tile_v = _select_tile_v_for_mtp(B, HV, V, T)

-    h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True)
-    inter_ = from_dlpack(intermediate_states, assumed_align=32, enable_tvm_ffi=True)
-    q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True)
-    k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True)
-    v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True)
-    a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True)
-    b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True)
-    A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True)
-    dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True)
-    o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True)
-    h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True)
-    h0_out_idx_ = from_dlpack(
-        output_state_indices, assumed_align=32, enable_tvm_ffi=True
-    )
-
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
-    major, _ = torch.cuda.get_device_capability(q.device)
-    use_packed_fma = major >= 10
+    use_packed_fma = _USE_PACKED_FMA

     cache_key = (
         "mtp_bf16",
@@ -2622,51 +2685,87 @@
         use_packed_fma,
     )
     if cache_key not in _compiled_kernels_mtp:
-        _compiled_kernels_mtp[cache_key] = cute.compile(
-            run_gdn_decode_bf16state_mtp,
-            h_,
-            inter_,
-            A_log_,
-            a_,
-            dt_bias_,
-            q_,
-            k_,
-            v_,
-            b_,
-            o_,
-            h0_idx_,
-            h0_out_idx_,
-            softplus_beta,
-            softplus_threshold,
-            scale,
-            HV,
-            B,
-            T,
-            H,
-            K,
-            V,
-            tile_v,
-            use_qk_l2norm_in_kernel,
-            disable_state_update,
-            cache_intermediate_states,
-            use_packed_fma,
-            stream,
-            options="--enable-tvm-ffi --generate-line-info --opt-level 3",
+        # First call for this shape: allocate default indices/output and do
+        # dlpack conversions once for compilation. Steady-state calls pass
+        # torch tensors straight to the compiled callable (tvm-ffi accepts
+        # either) and reuse these cached defaults when the caller doesn't
+        # provide their own.
+ default_indices = torch.arange(B, dtype=torch.int32, device=q.device) + default_output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) + + h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True) + inter_ = from_dlpack(intermediate_states, assumed_align=32, enable_tvm_ffi=True) + q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True) + k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True) + v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True) + a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True) + b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True) + A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True) + dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) + o_ = from_dlpack(default_output, assumed_align=32, enable_tvm_ffi=True) + h0_idx_ = from_dlpack(default_indices, assumed_align=32, enable_tvm_ffi=True) + h0_out_idx_ = from_dlpack( + default_indices, assumed_align=32, enable_tvm_ffi=True ) - _compiled_kernels_mtp[cache_key]( - h_, - inter_, - A_log_, - a_, - dt_bias_, - q_, - k_, - v_, - b_, - o_, - h0_idx_, - h0_out_idx_, + _compiled_kernels_mtp[cache_key] = { + "compiled": cute.compile( + run_gdn_decode_bf16state_mtp, + h_, + inter_, + A_log_, + a_, + dt_bias_, + q_, + k_, + v_, + b_, + o_, + h0_idx_, + h0_out_idx_, + softplus_beta, + softplus_threshold, + scale, + HV, + B, + T, + H, + K, + V, + tile_v, + use_qk_l2norm_in_kernel, + disable_state_update, + cache_intermediate_states, + use_packed_fma, + stream, + options="--enable-tvm-ffi --generate-line-info --opt-level 3", + ), + "default_indices": default_indices, + "output": default_output, + } + + cache = _compiled_kernels_mtp[cache_key] + + if initial_state_indices is None: + initial_state_indices = cache["default_indices"] + if output_state_indices is None: + output_state_indices = initial_state_indices + if output is None: + output = cache["output"] + + cache["compiled"]( + h0_source, + intermediate_states, + A_log, + a, + dt_bias, + q, + k, + v, + b, + output, + initial_state_indices, + output_state_indices, stream, )
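Reviewer note: a minimal usage sketch of the new pool+indices path. The pool shape, the -1 padding convention, and the reserved null slot 0 come from the docstring and kernel comments above; the import path, the concrete sizes, and the float32 layouts assumed for a and dt_bias are illustrative only, not confirmed by this patch.

    import torch

    from flashinfer.gdn_kernels.gdn_decode_bf16_state import gated_delta_rule

    dev = "cuda"
    B, T, H, HV, K, V = 32, 1, 16, 8, 128, 128  # B >= ILP_BATCH_THRESHOLD -> ILP kernel
    pool_size = 65  # slot 0 is reserved as the null buffer for negative indices

    state_pool = torch.zeros(pool_size, HV, V, K, dtype=torch.bfloat16, device=dev)

    q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=dev)
    k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device=dev)
    v = torch.randn(B, T, HV, V, dtype=torch.bfloat16, device=dev)
    b = torch.rand(B, T, HV, dtype=torch.bfloat16, device=dev)
    a = torch.randn(B, T, HV, dtype=torch.float32, device=dev)  # dtype assumed
    A_log = torch.randn(HV, dtype=torch.float32, device=dev)    # must be float32
    dt_bias = torch.randn(HV, dtype=torch.float32, device=dev)  # dtype assumed

    # One pool slot per live request; -1 marks a padded request, which the
    # kernel redirects to the reserved null slot 0 for both read and write.
    slots = torch.arange(1, B + 1, dtype=torch.int32, device=dev)
    slots[-1] = -1

    out = gated_delta_rule(
        A_log=A_log, a=a, dt_bias=dt_bias, q=q, k=k, v=v, b=b,
        initial_state_source=state_pool,
        initial_state_indices=slots,   # pool slots to read
        output_state_indices=slots,    # optional; defaults to the read slots
    )
    assert out.shape == (B, T, HV, V)  # live slots of state_pool updated in place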
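The negative-index redirection is easiest to check against a host-side reference. The sketch below restates the ILP kernel's index math in plain PyTorch (cache_idx < 0 redirects to slot 0, then flat_state_idx = slot * HV + i_hv into the [pool_size*HV, V, K] layout); it is a semantic reference for review, not code this module ships.

    import torch

    def flat_pool_slots(h0_indices, h0_out_indices, HV):
        # Kernel rule: a negative slot is redirected to the reserved null slot 0.
        read = h0_indices.clamp(min=0).to(torch.int64)
        write = h0_out_indices.clamp(min=0).to(torch.int64)
        i_hv = torch.arange(HV, device=h0_indices.device)
        flat_read = read[:, None] * HV + i_hv    # kernel's flat_state_idx
        flat_write = write[:, None] * HV + i_hv  # kernel's flat_write_idx
        return flat_read, flat_write

    # Batch of 3 with HV=2; request 2 is padding (-1 -> null slot 0).
    r, w = flat_pool_slots(
        torch.tensor([5, 7, -1], dtype=torch.int32),
        torch.tensor([6, 7, -1], dtype=torch.int32),
        HV=2,
    )
    print(r.tolist())  # [[10, 11], [14, 15], [0, 1]]
    print(w.tolist())  # [[12, 13], [14, 15], [0, 1]]

Keeping the read and write flattenings independent is what lets a scheduler read a shared prefix state from one slot while persisting the updated state to a different slot.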
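The caching refactor in both dispatchers follows one pattern: key the compiled kernel by every compile-time parameter, stash the first-call default tensors next to the callable, and keep the steady-state path free of allocations, dlpack wrapping, and device queries. A distilled sketch of that pattern follows; compile_kernel is a stand-in stub, not this module's API.

    import torch

    def compile_kernel(x):
        # Stand-in for cute.compile: returns a callable specialized to x's shape.
        def kernel(inp, indices, out):
            out.copy_(inp)  # placeholder for the real kernel launch
            return out
        return kernel

    _cache = {}

    def dispatch(x, indices=None, out=None):
        # Key everything the kernel bakes in as a Constexpr (shape, dtype,
        # tile size, arch flags, ...); a missing key term means stale kernels.
        key = (tuple(x.shape), x.dtype, str(x.device))
        if key not in _cache:
            # First call for this key: allocate defaults once, compile once.
            _cache[key] = {
                "compiled": compile_kernel(x),
                "default_indices": torch.arange(
                    x.shape[0], dtype=torch.int32, device=x.device
                ),
                "output": torch.empty_like(x),
            }
        entry = _cache[key]
        # Steady state: reuse cached defaults when the caller passes none.
        indices = entry["default_indices"] if indices is None else indices
        out = entry["output"] if out is None else out
        return entry["compiled"](x, indices, out)

One caveat inherent to this pattern, which applies to the patch as well: the cached default output buffer is rewritten by every call that does not pass its own output, so callers must consume or copy the returned tensor before the next dispatch.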