From 28cbec5c24d9a7d1a2760c9dc91c3b63e2f58e92 Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Sat, 7 Feb 2026 13:25:32 +0800 Subject: [PATCH 1/9] perf(gdn): add pooled decode kernel with direct pool indexing and in-kernel negative index handling Add use_pool_indexing constexpr to both small-batch and big-batch pretranspose decode kernels, enabling zero-copy state access directly from the pool via h0_indices, eliminating gather/scatter overhead. Also handle negative pool indices (padding slots) inside the kernel: blocks with negative indices skip computation and write zeros to output, removing the need for host-side torch.where remap (~37us/call savings). Combined effect: K-last decode is 4-5.6% faster than V-last at BS>=4. --- flashinfer/gdn_decode.py | 859 +++++++++++++++++++++++++-------------- 1 file changed, 561 insertions(+), 298 deletions(-) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index e64c231686..0747c7c9b1 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -158,6 +158,7 @@ def gdn_decode_kernel_small_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], + use_pool_indexing: cutlass.Constexpr[bool] = False, ): """Each block uses pipeline to load one batch and vectorized writeback""" @@ -218,189 +219,208 @@ def gdn_decode_kernel_small_batch_pretranspose( cute.arch.barrier() - # Get current batch - gSrc_batch = h0_source[(batch_idx, None, None)] # (V, K) - gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (batch_idx, None, 0)) + # Compute state index: use pool indexing if enabled (h0_indices maps batch to pool slot) + if cutlass.const_expr(use_pool_indexing): + pool_idx = h0_indices[i_n] + state_idx = pool_idx * HV + i_hv + else: + pool_idx = 0 # dummy, always valid when not using pool indexing + state_idx = batch_idx - # V 方向分 tiles - gSrc = cute.local_tile( - gSrc_batch, (TILE_V, TILE_K), (None, 0) - ) # (TILE_V, TILE_K, num_v_tiles) + # When pool indexing: skip computation for padding slots (negative indices) + # and write zeros to output. When not pool indexing: always valid (pool_idx=0). + if pool_idx >= 0: + # Get current batch + gSrc_batch = h0_source[(state_idx, None, None)] # (V, K) + gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (state_idx, None, 0)) - # Partition for load - thr_copy_load = tiled_copy_load.get_slice(tidx) + # V 方向分 tiles + gSrc = cute.local_tile( + gSrc_batch, (TILE_V, TILE_K), (None, 0) + ) # (TILE_V, TILE_K, num_v_tiles) - # =================================================================== - # Prefetch: All threads participate in cp.async load - # =================================================================== - start_v_tiles = batch_inner * num_v_tiles_per_block - prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block) - for v_tiles in range(start_v_tiles, start_v_tiles + prefetch_count): - stage = (v_tiles - start_v_tiles) % NUM_STAGES + # Partition for load + thr_copy_load = tiled_copy_load.get_slice(tidx) - gSrc_tile = gSrc[(None, None, v_tiles)] - sData_stage = sData[(None, None, stage)] + # =================================================================== + # Prefetch: All threads participate in cp.async load + # =================================================================== + start_v_tiles = batch_inner * num_v_tiles_per_block + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block) + for v_tiles in range(start_v_tiles, start_v_tiles + prefetch_count): + stage = (v_tiles - start_v_tiles) % NUM_STAGES - thr_gSrc = thr_copy_load.partition_S(gSrc_tile) - thr_sData = thr_copy_load.partition_D(sData_stage) + gSrc_tile = gSrc[(None, None, v_tiles)] + sData_stage = sData[(None, None, stage)] - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) - # Load q, k into BF16 registers using autovec_copy (contiguous pattern) - q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - cute.autovec_copy(q_tile, r_q_bf16) - cute.autovec_copy(k_tile, r_k_bf16) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - # Convert BF16 to FP32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = cutlass.Float32(r_q_bf16[i]) - r_k[i] = cutlass.Float32(r_k_bf16[i]) + # Load q, k into BF16 registers using autovec_copy (contiguous pattern) + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) - # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV - v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) - cute.autovec_copy(v_tile, r_v_bf16) - for i in cutlass.range_constexpr(vec_size): - sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) + # Convert BF16 to FP32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) - cute.arch.barrier() # Ensure all threads finish writing to sV + # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV + v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) + cute.autovec_copy(v_tile, r_v_bf16) + for i in cutlass.range_constexpr(vec_size): + sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) - # =================================================================== - # Compute g and beta (scalar values) - # =================================================================== - r_g = 0.0 - r_beta = 0.0 - if lane_id == 0: - x = r_a + r_dt_bias - beta_x = softplus_beta * x - softplus_x = 0.0 - - if beta_x <= softplus_threshold: - # softplus(x) = (1/beta) * log(1 + exp(beta*x)) - # Compute in Float32 - exp_beta_x = cute.exp(beta_x, fastmath=True) - log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) - softplus_x = cutlass.Float32( - (cutlass.Float32(1.0) / softplus_beta) * log_result - ) - else: - softplus_x = x + cute.arch.barrier() # Ensure all threads finish writing to sV - # Compute g = exp(A_log) * softplus_x - r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + # =================================================================== + # Compute g and beta (scalar values) + # =================================================================== + r_g = 0.0 + r_beta = 0.0 + if lane_id == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 - # Compute beta = 1 / (1 + exp(-b)) - r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) + if beta_x <= softplus_threshold: + # softplus(x) = (1/beta) * log(1 + exp(beta*x)) + # Compute in Float32 + exp_beta_x = cute.exp(beta_x, fastmath=True) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x - # Store to scalar (Float32) - r_g = cute.exp(r_g_value, fastmath=True) + # Compute g = exp(A_log) * softplus_x + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x - r_g = cute.arch.shuffle_sync(r_g, 0) - r_beta = cute.arch.shuffle_sync(r_beta, 0) + # Compute beta = 1 / (1 + exp(-b)) + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) - if use_qk_l2norm: - # Compute L2 norm of q and k - sum_q = 0.0 - sum_k = 0.0 - for i in cutlass.range_constexpr(vec_size): - sum_q += r_q[i] * r_q[i] - sum_k += r_k[i] * r_k[i] - # Warp-level reduction using butterfly shuffle - for offset in [16, 8, 4, 2, 1]: - sum_q += cute.arch.shuffle_sync_bfly( - sum_q, offset=offset, mask=-1, mask_and_clamp=31 - ) - sum_k += cute.arch.shuffle_sync_bfly( - sum_k, offset=offset, mask=-1, mask_and_clamp=31 - ) + # Store to scalar (Float32) + r_g = cute.exp(r_g_value, fastmath=True) - inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) - inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) - for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * inv_norm_q - r_k[i] = r_k[i] * inv_norm_k + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) - # Apply scaling in Float32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * scale + if use_qk_l2norm: + # Compute L2 norm of q and k + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + # Warp-level reduction using butterfly shuffle + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) - # =================================================================== - # Mainloop: All threads participate - # =================================================================== - end_v_tiles = start_v_tiles + num_v_tiles_per_block - for v_tiles in range(start_v_tiles, end_v_tiles): - stage = (v_tiles - start_v_tiles) % NUM_STAGES + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k - # Step 1: Wait for current stage to complete - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + # Apply scaling in Float32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale - # Step 2: Issue async load for next tile (after compute) - next_v_tiles = v_tiles + prefetch_count - if next_v_tiles < end_v_tiles: - next_stage = (next_v_tiles - start_v_tiles) % NUM_STAGES + # =================================================================== + # Mainloop: All threads participate + # =================================================================== + end_v_tiles = start_v_tiles + num_v_tiles_per_block + for v_tiles in range(start_v_tiles, end_v_tiles): + stage = (v_tiles - start_v_tiles) % NUM_STAGES - gSrc_next = gSrc[(None, None, next_v_tiles)] - sData_next = sData[(None, None, next_stage)] + # Step 1: Wait for current stage to complete + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() - thr_gSrc = thr_copy_load.partition_S(gSrc_next) - thr_sData = thr_copy_load.partition_D(sData_next) + # Step 2: Issue async load for next tile (after compute) + next_v_tiles = v_tiles + prefetch_count + if next_v_tiles < end_v_tiles: + next_stage = (next_v_tiles - start_v_tiles) % NUM_STAGES - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + gSrc_next = gSrc[(None, None, next_v_tiles)] + sData_next = sData[(None, None, next_stage)] - # Step 3: Compute using data from current stage (contiguous access pattern) - for row in cutlass.range_constexpr(0, TILE_V, 4): - row_offset = tidx // 32 - sum_hk = 0.0 + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) - # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) - sData_tile = cute.local_tile( - sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) - ) - cute.autovec_copy(sData_tile, r_h) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - for i in cutlass.range_constexpr(vec_size): - r_h[i] = r_h[i] * r_g - sum_hk += r_h[i] * r_k[i] + # Step 3: Compute using data from current stage (contiguous access pattern) + for row in cutlass.range_constexpr(0, TILE_V, 4): + row_offset = tidx // 32 + sum_hk = 0.0 - for offset in [16, 8, 4, 2, 1]: - sum_hk += cute.arch.shuffle_sync_bfly( - sum_hk, offset=offset, mask=-1, mask_and_clamp=31 + # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) + sData_tile = cute.local_tile( + sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) ) + cute.autovec_copy(sData_tile, r_h) - v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk - v_new = v_new * r_beta + for i in cutlass.range_constexpr(vec_size): + r_h[i] = r_h[i] * r_g + sum_hk += r_h[i] * r_k[i] - sum_hq = 0.0 - for i in cutlass.range_constexpr(vec_size): - r_h[i] += r_k[i] * v_new - sum_hq += r_h[i] * r_q[i] + for offset in [16, 8, 4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, offset=offset, mask=-1, mask_and_clamp=31 + ) - # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) - gDst_tile = cute.local_tile( - gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) - ) - cute.autovec_copy(r_h, gDst_tile) + v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk + v_new = v_new * r_beta - for offset in [16, 8, 4, 2, 1]: - sum_hq += cute.arch.shuffle_sync_bfly( - sum_hq, offset=offset, mask=-1, mask_and_clamp=31 + sum_hq = 0.0 + for i in cutlass.range_constexpr(vec_size): + r_h[i] += r_k[i] * v_new + sum_hq += r_h[i] * r_q[i] + + # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) + gDst_tile = cute.local_tile( + gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) ) + cute.autovec_copy(r_h, gDst_tile) - o_idx = v_tiles * TILE_V + row + row_offset - if lane_id == 0 and o_idx < V: - sOutput[o_idx] = cutlass.BFloat16(sum_hq) + for offset in [16, 8, 4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, offset=offset, mask=-1, mask_and_clamp=31 + ) - # =================================================================== - # Final writeback: Copy output from shared memory to global memory - # All threads write (V=128, NUM_THREADS=128) - # =================================================================== - cute.arch.barrier() # Ensure all writes to sOutput are complete - if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V: - o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] + o_idx = v_tiles * TILE_V + row + row_offset + if lane_id == 0 and o_idx < V: + sOutput[o_idx] = cutlass.BFloat16(sum_hq) + + # =================================================================== + # Final writeback: Copy output from shared memory to global memory + # All threads write (V=128, NUM_THREADS=128) + # =================================================================== + cute.arch.barrier() # Ensure all writes to sOutput are complete + if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V: + o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] + else: + # Padding slot: write zeros to output + start_v_tiles = batch_inner * num_v_tiles_per_block + if ( + tidx >= start_v_tiles * TILE_V + and tidx < (start_v_tiles + num_v_tiles_per_block) * TILE_V + ): + o[(i_n, i_t, i_hv, tidx)] = cutlass.BFloat16(0.0) @cute.kernel @@ -432,6 +452,7 @@ def gdn_decode_kernel_big_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], + use_pool_indexing: cutlass.Constexpr[bool] = False, ): """Each block uses pipeline to load one batch and vectorized writeback""" @@ -489,188 +510,201 @@ def gdn_decode_kernel_big_batch_pretranspose( cute.arch.barrier() - # Get current batch - gSrc_batch = h0_source[(batch_idx, None, None)] # (V, K) - gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (batch_idx, None, 0)) + # Compute state index: use pool indexing if enabled (h0_indices maps batch to pool slot) + if cutlass.const_expr(use_pool_indexing): + pool_idx = h0_indices[i_n] + state_idx = pool_idx * HV + i_hv + else: + pool_idx = 0 # dummy, always valid + state_idx = batch_idx - # V 方向分 tiles - gSrc = cute.local_tile( - gSrc_batch, (TILE_V, TILE_K), (None, 0) - ) # (TILE_V, TILE_K, num_v_tiles) + if pool_idx >= 0: + # Get current batch + gSrc_batch = h0_source[(state_idx, None, None)] # (V, K) + gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (state_idx, None, 0)) - # Partition for load - thr_copy_load = tiled_copy_load.get_slice(tidx) + # V 方向分 tiles + gSrc = cute.local_tile( + gSrc_batch, (TILE_V, TILE_K), (None, 0) + ) # (TILE_V, TILE_K, num_v_tiles) - # =================================================================== - # Prefetch: All threads participate in cp.async load - # =================================================================== - prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles) - for v_tiles in range(prefetch_count): - stage = v_tiles % NUM_STAGES + # Partition for load + thr_copy_load = tiled_copy_load.get_slice(tidx) - gSrc_tile = gSrc[(None, None, v_tiles)] - sData_stage = sData[(None, None, stage)] + # =================================================================== + # Prefetch: All threads participate in cp.async load + # =================================================================== + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles) + for v_tiles in range(prefetch_count): + stage = v_tiles % NUM_STAGES - thr_gSrc = thr_copy_load.partition_S(gSrc_tile) - thr_sData = thr_copy_load.partition_D(sData_stage) + gSrc_tile = gSrc[(None, None, v_tiles)] + sData_stage = sData[(None, None, stage)] - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) - # Load q, k into BF16 registers using autovec_copy (contiguous pattern) - q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - cute.autovec_copy(q_tile, r_q_bf16) - cute.autovec_copy(k_tile, r_k_bf16) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - # Convert BF16 to FP32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = cutlass.Float32(r_q_bf16[i]) - r_k[i] = cutlass.Float32(r_k_bf16[i]) + # Load q, k into BF16 registers using autovec_copy (contiguous pattern) + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) - # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV - v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) - cute.autovec_copy(v_tile, r_v_bf16) - for i in cutlass.range_constexpr(vec_size): - sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) + # Convert BF16 to FP32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) - cute.arch.barrier() # Ensure all threads finish writing to sV + # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV + v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) + cute.autovec_copy(v_tile, r_v_bf16) + for i in cutlass.range_constexpr(vec_size): + sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) - # =================================================================== - # Compute g and beta (scalar values) - # =================================================================== - r_g = 0.0 - r_beta = 0.0 - if lane_id == 0: - x = r_a + r_dt_bias - beta_x = softplus_beta * x - softplus_x = 0.0 - - if beta_x <= softplus_threshold: - # softplus(x) = (1/beta) * log(1 + exp(beta*x)) - # Compute in Float32 - exp_beta_x = cute.exp(beta_x, fastmath=True) - log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) - softplus_x = cutlass.Float32( - (cutlass.Float32(1.0) / softplus_beta) * log_result - ) - else: - softplus_x = x + cute.arch.barrier() # Ensure all threads finish writing to sV - # Compute g = exp(A_log) * softplus_x - r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + # =================================================================== + # Compute g and beta (scalar values) + # =================================================================== + r_g = 0.0 + r_beta = 0.0 + if lane_id == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 - # Compute beta = 1 / (1 + exp(-b)) - r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) + if beta_x <= softplus_threshold: + # softplus(x) = (1/beta) * log(1 + exp(beta*x)) + # Compute in Float32 + exp_beta_x = cute.exp(beta_x, fastmath=True) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x - # Store to scalar (Float32) - r_g = cute.exp(r_g_value, fastmath=True) + # Compute g = exp(A_log) * softplus_x + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x - r_g = cute.arch.shuffle_sync(r_g, 0) - r_beta = cute.arch.shuffle_sync(r_beta, 0) + # Compute beta = 1 / (1 + exp(-b)) + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) - if use_qk_l2norm: - # Compute L2 norm of q and k - sum_q = 0.0 - sum_k = 0.0 - for i in cutlass.range_constexpr(vec_size): - sum_q += r_q[i] * r_q[i] - sum_k += r_k[i] * r_k[i] - # Warp-level reduction using butterfly shuffle - for offset in [16, 8, 4, 2, 1]: - sum_q += cute.arch.shuffle_sync_bfly( - sum_q, offset=offset, mask=-1, mask_and_clamp=31 - ) - sum_k += cute.arch.shuffle_sync_bfly( - sum_k, offset=offset, mask=-1, mask_and_clamp=31 - ) + # Store to scalar (Float32) + r_g = cute.exp(r_g_value, fastmath=True) - inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) - inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) - for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * inv_norm_q - r_k[i] = r_k[i] * inv_norm_k + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) - # Apply scaling in Float32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * scale + if use_qk_l2norm: + # Compute L2 norm of q and k + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + # Warp-level reduction using butterfly shuffle + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) - # =================================================================== - # Mainloop: All threads participate - # =================================================================== - for v_tiles in range(num_v_tiles): - stage = v_tiles % NUM_STAGES + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k - # Step 1: Wait for current stage to complete - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + # Apply scaling in Float32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale - # Step 2: Issue async load for next tile (after compute) - next_v_tiles = v_tiles + prefetch_count - if next_v_tiles < num_v_tiles: - next_stage = next_v_tiles % NUM_STAGES + # =================================================================== + # Mainloop: All threads participate + # =================================================================== + for v_tiles in range(num_v_tiles): + stage = v_tiles % NUM_STAGES - gSrc_next = gSrc[(None, None, next_v_tiles)] - sData_next = sData[(None, None, next_stage)] + # Step 1: Wait for current stage to complete + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() - thr_gSrc = thr_copy_load.partition_S(gSrc_next) - thr_sData = thr_copy_load.partition_D(sData_next) + # Step 2: Issue async load for next tile (after compute) + next_v_tiles = v_tiles + prefetch_count + if next_v_tiles < num_v_tiles: + next_stage = next_v_tiles % NUM_STAGES - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + gSrc_next = gSrc[(None, None, next_v_tiles)] + sData_next = sData[(None, None, next_stage)] - # Step 3: Compute using data from current stage (contiguous access pattern) - for row in cutlass.range_constexpr(0, TILE_V, 4): - row_offset = tidx // 32 - sum_hk = 0.0 + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) - # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) - sData_tile = cute.local_tile( - sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) - ) - cute.autovec_copy(sData_tile, r_h) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - for i in cutlass.range_constexpr(vec_size): - r_h[i] = r_h[i] * r_g - sum_hk += r_h[i] * r_k[i] + # Step 3: Compute using data from current stage (contiguous access pattern) + for row in cutlass.range_constexpr(0, TILE_V, 4): + row_offset = tidx // 32 + sum_hk = 0.0 - for offset in [16, 8, 4, 2, 1]: - sum_hk += cute.arch.shuffle_sync_bfly( - sum_hk, offset=offset, mask=-1, mask_and_clamp=31 + # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) + sData_tile = cute.local_tile( + sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) ) + cute.autovec_copy(sData_tile, r_h) - v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk - v_new = v_new * r_beta + for i in cutlass.range_constexpr(vec_size): + r_h[i] = r_h[i] * r_g + sum_hk += r_h[i] * r_k[i] - sum_hq = 0.0 - for i in cutlass.range_constexpr(vec_size): - r_h[i] += r_k[i] * v_new - sum_hq += r_h[i] * r_q[i] + for offset in [16, 8, 4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, offset=offset, mask=-1, mask_and_clamp=31 + ) - # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) - gDst_tile = cute.local_tile( - gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) - ) - cute.autovec_copy(r_h, gDst_tile) + v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk + v_new = v_new * r_beta - for offset in [16, 8, 4, 2, 1]: - sum_hq += cute.arch.shuffle_sync_bfly( - sum_hq, offset=offset, mask=-1, mask_and_clamp=31 + sum_hq = 0.0 + for i in cutlass.range_constexpr(vec_size): + r_h[i] += r_k[i] * v_new + sum_hq += r_h[i] * r_q[i] + + # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) + gDst_tile = cute.local_tile( + gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) ) + cute.autovec_copy(r_h, gDst_tile) - o_idx = v_tiles * TILE_V + row + row_offset - if lane_id == 0 and o_idx < V: - sOutput[o_idx] = cutlass.BFloat16(sum_hq) + for offset in [16, 8, 4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, offset=offset, mask=-1, mask_and_clamp=31 + ) - # =================================================================== - # Final writeback: Copy output from shared memory to global memory - # All threads write (V=128, NUM_THREADS=128) - # =================================================================== - cute.arch.barrier() # Ensure all writes to sOutput are complete + o_idx = v_tiles * TILE_V + row + row_offset + if lane_id == 0 and o_idx < V: + sOutput[o_idx] = cutlass.BFloat16(sum_hq) + + # =================================================================== + # Final writeback: Copy output from shared memory to global memory + # All threads write (V=128, NUM_THREADS=128) + # =================================================================== + cute.arch.barrier() # Ensure all writes to sOutput are complete - if tidx < V: - o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] + if tidx < V: + o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] + else: + # Padding slot (negative pool index): write zeros to output + if tidx < V: + o[(i_n, i_t, i_hv, tidx)] = cutlass.BFloat16(0.0) @cute.jit @@ -698,10 +732,11 @@ def run_gdn_decode_kernel_small_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], - stream: cuda.CUstream, + use_pool_indexing: cutlass.Constexpr[bool] = False, + stream: cuda.CUstream = None, ): """Launch original pipelined kernel for small batch pretranspose.""" - # h0_source: (B*HV, V, K) + # h0_source: (B*HV, V, K) or (pool_size*HV, V, K) when use_pool_indexing=True batch_size, v_dim, k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], @@ -741,6 +776,9 @@ def run_gdn_decode_kernel_small_batch_pretranspose( # sOutput: V * 2 bytes (BFloat16) smem_bytes = 4 * TILE_V * TILE_K * NUM_STAGES + 4 * k_dim + 2 * v_dim + 32 + # Grid size: use B*HV (actual batch) not h0_source.shape[0] (which may be pool_size*HV) + grid_batch = B * HV + gdn_decode_kernel_small_batch_pretranspose( tiled_copy_load, h0_source, @@ -769,8 +807,9 @@ def run_gdn_decode_kernel_small_batch_pretranspose( use_initial_state, use_qk_l2norm, is_varlen, + use_pool_indexing, ).launch( - grid=(batch_size * NUM_BLOCKS_PER_STATE, 1, 1), + grid=(grid_batch * NUM_BLOCKS_PER_STATE, 1, 1), block=[NUM_THREADS, 1, 1], smem=smem_bytes, stream=stream, @@ -802,9 +841,10 @@ def run_gdn_decode_kernel_big_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], - stream: cuda.CUstream, + use_pool_indexing: cutlass.Constexpr[bool] = False, + stream: cuda.CUstream = None, ): - # h0_source: (B*HV, V, K) + # h0_source: (B*HV, V, K) or (pool_size*HV, V, K) when use_pool_indexing=True batch_size, v_dim, k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], @@ -850,6 +890,9 @@ def run_gdn_decode_kernel_big_batch_pretranspose( # sOutput: V * 2 bytes (BFloat16) smem_bytes = 4 * TILE_V * TILE_K * NUM_STAGES + 4 * k_dim + 2 * v_dim + 32 + # Grid size: use B*HV (actual batch) not h0_source.shape[0] (which may be pool_size*HV) + grid_batch = B * HV + gdn_decode_kernel_big_batch_pretranspose( tiled_copy_load, h0_source, @@ -878,8 +921,9 @@ def run_gdn_decode_kernel_big_batch_pretranspose( use_initial_state, use_qk_l2norm, is_varlen, + use_pool_indexing, ).launch( - grid=(batch_size, 1, 1), + grid=(grid_batch, 1, 1), block=[NUM_THREADS, 1, 1], smem=smem_bytes, stream=stream, @@ -1105,6 +1149,225 @@ def gated_delta_rule_decode_pretranspose( return output, state +@functools.cache +def _get_compiled_decode_kernel_pooled( + B: int, + T: int, + H: int, + HV: int, + K: int, + V: int, + pool_size: int, + dtype: torch.dtype, + scale: float, + use_qk_l2norm: bool, +): + """Cache compiled kernel for pooled pretranspose decode (different constexpr → separate cache).""" + return {} + + +@flashinfer_api +def gated_delta_rule_decode_pretranspose_pooled( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor, + initial_state_indices: torch.Tensor, + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + b: torch.Tensor, + scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, + use_qk_l2norm: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Gated Delta Rule Decode kernel with direct pool indexing (zero-copy). + + Like ``gated_delta_rule_decode_pretranspose`` but reads/writes state directly + from/to the state pool using ``initial_state_indices``, eliminating the need + for gather before and scatter (index_copy_) after the kernel call. + + Args: + q (torch.Tensor): + Current query of shape ``[B, 1, H, K]``. Must be float16/bfloat16. + k (torch.Tensor): + Current key of shape ``[B, 1, H, K]``. Must be float16/bfloat16. + v (torch.Tensor): + Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16. + initial_state (torch.Tensor): + State pool of shape ``[pool_size, HV, V, K]`` (K-last / V-major layout). + Must be float32. Will be updated in-place at the indexed positions. + initial_state_indices (torch.Tensor): + Indices mapping each batch element to its pool slot, shape ``[B]``. + Must be int32. Negative indices are treated as padding (skipped). + A_log (torch.Tensor): + Log decay parameter of shape ``[HV]``. Must be float32. + a (torch.Tensor): + Input-dependent decay of shape ``[B, 1, HV]``. Must be float16/bfloat16. + dt_bias (torch.Tensor): + Decay bias of shape ``[HV]``. Must be bfloat16 or float32. + b (torch.Tensor): + Update gate (beta) input of shape ``[B, 1, HV]``. Must be float16/bfloat16. + scale (Optional[float]): + Scale factor for queries. If None, defaults to ``1 / sqrt(K)``. + output (Optional[torch.Tensor]): + Pre-allocated output tensor of shape ``[B, 1, HV, V]``. + If None, will be allocated automatically. + use_qk_l2norm (bool): + Whether to apply L2 normalization to q and k. Default: ``True``. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - output: Output tensor of shape ``[B, 1, HV, V]`` + - initial_state: The same state pool tensor (updated in-place) + + Note: + - Requires SM90 (Hopper) architecture + - State is updated in-place at positions given by initial_state_indices + - K and V must be multiples of 4 for vectorized loads + - State layout is V-major: [pool_size, HV, V, K] + - Unlike ``gated_delta_rule_decode_pretranspose``, no gather/scatter needed + """ + # Validate input shapes + B, T, H, K = q.shape + assert T == 1, f"Decode only supports T=1, got T={T}" + _, _, HV, V = v.shape + pool_size = initial_state.shape[0] + + # Validate state shape + assert initial_state.shape == (pool_size, HV, V, K), ( + f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], got {initial_state.shape}" + ) + + # Validate indices + assert initial_state_indices.shape == (B,), ( + f"Expected initial_state_indices shape [{B}], got {initial_state_indices.shape}" + ) + assert initial_state_indices.dtype == torch.int32, ( + f"initial_state_indices must be int32, got {initial_state_indices.dtype}" + ) + + # Validate K and V constraints + assert K >= 128, f"K must be at least 128, got K={K}" + assert V >= 128, f"V must be at least 128, got V={V}" + assert V % TILE_V == 0, ( + f"V must be divisible by {TILE_V} to prevent out-of-bounds access, got V={V}" + ) + + # Validate dtypes + assert q.dtype in (torch.float16, torch.bfloat16), ( + f"q must be float16/bfloat16, got {q.dtype}" + ) + assert initial_state.dtype == torch.float32, ( + f"initial_state must be float32, got {initial_state.dtype}" + ) + assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" + + # Set default scale + if scale is None: + scale = K**-0.5 + + # Allocate output if not provided + output_provided = output is not None + target_dtype = output.dtype if output_provided else q.dtype + + if output is None: + output = torch.zeros((B, T, HV, V), dtype=torch.bfloat16, device=q.device) + + # Reshape state pool from [pool_size, HV, V, K] to [pool_size*HV, V, K] + h0_source = initial_state.reshape(pool_size * HV, V, K) + + # Compile kernel with TVM FFI (cached, separate cache from non-pooled version) + cache_key = (B, T, H, HV, K, V, pool_size, q.dtype, scale, use_qk_l2norm) + cache = _get_compiled_decode_kernel_pooled(*cache_key) + + # Get or create cu_seqlens (cached per config) + if "cu_seqlens" not in cache or cache["cu_seqlens"].device != q.device: + cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) + cu_seqlens = cache["cu_seqlens"] + + if "compiled" not in cache: + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Convert tensors to CuTe format for compilation only + h0_source_tensor = from_dlpack(h0_source, assumed_align=16) + A_log_tensor = from_dlpack(A_log, assumed_align=16) + a_tensor = from_dlpack(a, assumed_align=16) + dt_bias_tensor = from_dlpack(dt_bias, assumed_align=16) + q_tensor = from_dlpack(q, assumed_align=16) + k_tensor = from_dlpack(k, assumed_align=16) + v_tensor = from_dlpack(v, assumed_align=16) + b_tensor = from_dlpack(b, assumed_align=16) + o_tensor = from_dlpack(output, assumed_align=16) + h0_indices_tensor = from_dlpack(initial_state_indices, assumed_align=16) + cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16) + + # Choose kernel based on batch size + if B <= 32: + run_func = run_gdn_decode_kernel_small_batch_pretranspose + else: + run_func = run_gdn_decode_kernel_big_batch_pretranspose + + # Use TVM FFI to reduce runtime overhead + compiled = cute.compile( + run_func, + h0_source_tensor, + A_log_tensor, + a_tensor, + dt_bias_tensor, + q_tensor, + k_tensor, + v_tensor, + b_tensor, + o_tensor, + h0_indices_tensor, + cu_seqlens_tensor, + softplus_beta=1.0, + softplus_threshold=20.0, + scale=scale, + HV=HV, + B=B, + T=T, + H=H, + K=K, + V=V, + use_initial_state=True, + use_qk_l2norm=use_qk_l2norm, + is_varlen=False, + use_pool_indexing=True, + stream=stream, + options="--enable-tvm-ffi", + ) + cache["compiled"] = compiled + else: + compiled = cache["compiled"] + + # Run kernel directly with PyTorch tensors (no from_dlpack needed) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + cache["compiled"]( + h0_source, + A_log, + a, + dt_bias, + q, + k, + v, + b, + output, + initial_state_indices, + cu_seqlens, + stream, + ) + + # State is updated in-place via pool indexing — no copy needed + + # Convert output to target dtype if needed (kernel outputs bfloat16) + if output.dtype != target_dtype: + output = output.to(target_dtype) + + return output, initial_state + + # ============================================================================ # NONTRANSPOSE Version Kernels - K-major layout [pool, HV, K, V] # ============================================================================ From 290a390b3dcf91ff548a27170e06058c7869f9da Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Sat, 7 Feb 2026 18:55:37 +0800 Subject: [PATCH 2/9] chore(gdn): remove dead expressions and commented-out debug prints in decode launchers --- flashinfer/gdn_decode.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 0747c7c9b1..cbf1e09eae 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -737,7 +737,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose( ): """Launch original pipelined kernel for small batch pretranspose.""" # h0_source: (B*HV, V, K) or (pool_size*HV, V, K) when use_pool_indexing=True - batch_size, v_dim, k_dim = ( + _, v_dim, k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], h0_source.layout.shape[2], @@ -760,7 +760,6 @@ def run_gdn_decode_kernel_small_batch_pretranspose( tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout) num_v_tiles = cute.ceil_div(v_dim, TILE_V) - v_dim * k_dim * batch_size * 4 / 1024 / 1024 vec_size = ( TILE_K // 32 @@ -845,7 +844,7 @@ def run_gdn_decode_kernel_big_batch_pretranspose( stream: cuda.CUstream = None, ): # h0_source: (B*HV, V, K) or (pool_size*HV, V, K) when use_pool_indexing=True - batch_size, v_dim, k_dim = ( + _, v_dim, k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], h0_source.layout.shape[2], @@ -868,18 +867,11 @@ def run_gdn_decode_kernel_big_batch_pretranspose( tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout) num_v_tiles = cute.ceil_div(v_dim, TILE_V) - v_dim * k_dim * batch_size * 4 / 1024 / 1024 vec_size = ( TILE_K // 32 ) # Each thread in a warp processes this many elements (always 4 for TILE_K=128) - # print(f"Batched CP.ASYNC Load + Store (bypass L1 cache)") - # print(f" {batch_size} batches x {v_dim}x{k_dim} matrices") - # print(f" Tile: {TILE_V}x{TILE_K}, {num_v_tiles} tiles/batch") - # print(f" Threads: {NUM_THREADS} ({NUM_THREADS // 32} warps), vec_size: {vec_size}") - # print(f" Total: {total_data_mb:.1f} MB\n") - # Create SMEM layout smem_layout_staged = cute.make_layout( (TILE_V, TILE_K, NUM_STAGES), stride=(TILE_K, 1, TILE_V * TILE_K) From 452f922b2f1d29878fc8c09da9657e7bcfed6305 Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Sat, 7 Feb 2026 20:13:17 +0800 Subject: [PATCH 3/9] refactor(gdn): merge pooled and non-pooled pretranspose decode into single function Consolidate gated_delta_rule_decode_pretranspose_pooled into gated_delta_rule_decode_pretranspose by adding an optional state_indices parameter. When state_indices is provided, the kernel uses pool-indexed (zero-copy) mode; otherwise it uses direct 1:1 batch-to-state mapping. This eliminates ~175 lines of duplicated Python wrapper code while the underlying CUDA kernels remain unchanged. The compiled kernel cache key now includes pool_size and use_pool_indexing to ensure correct cache separation between the two modes. --- flashinfer/gdn_decode.py | 319 ++++++++++----------------------------- 1 file changed, 78 insertions(+), 241 deletions(-) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index cbf1e09eae..0d5459c83b 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -935,28 +935,20 @@ def _get_compiled_decode_kernel( HV: int, K: int, V: int, + pool_size: int, dtype: torch.dtype, scale: float, use_qk_l2norm: bool, + use_pool_indexing: bool, ): - """Cache compiled kernel for given configuration (pretranspose version).""" - # This will be populated on first call - return {} - + """Cache compiled kernel for given configuration (pretranspose version). -@functools.cache -def _get_compiled_decode_kernel_nontranspose( - B: int, - T: int, - H: int, - HV: int, - K: int, - V: int, - dtype: torch.dtype, - scale: float, - use_qk_l2norm: bool, -): - """Cache compiled kernel for given configuration (nontranspose version).""" + When ``use_pool_indexing=True``, the kernel reads/writes state from a shared + pool using ``state_indices``. Because ``use_pool_indexing`` is a + ``cutlass.Constexpr``, the two modes produce different compiled CUDA code and + must have separate cache entries (ensured by including ``pool_size`` and + ``use_pool_indexing`` in the key). + """ # This will be populated on first call return {} @@ -974,12 +966,18 @@ def gated_delta_rule_decode_pretranspose( scale: Optional[float] = None, output: Optional[torch.Tensor] = None, use_qk_l2norm: bool = True, + state_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Gated Delta Rule Decode kernel for single-token generation. This implements the decode phase of gated delta rule linear attention, processing one token at a time and updating the recurrent state. + When ``state_indices`` is provided, the kernel reads/writes state directly + from/to the state pool using indirect indexing (zero-copy pooled mode), + eliminating the need for gather before and scatter after the kernel call. + Negative indices are treated as padding and skipped. + Args: q (torch.Tensor): Current query of shape ``[B, 1, H, K]``. Must be float16/bfloat16. @@ -988,8 +986,9 @@ def gated_delta_rule_decode_pretranspose( v (torch.Tensor): Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16. state (torch.Tensor): - Current state of shape ``[B, HV, V, K]`` (v-major layout). - Must be float32. Will be updated in-place. + Current state of shape ``[B, HV, V, K]`` or, when ``state_indices`` + is provided, the full state pool of shape ``[pool_size, HV, V, K]`` + (V-major / K-last layout). Must be float32. Updated in-place. A_log (torch.Tensor): Log decay parameter of shape ``[HV]``. Must be float32. a (torch.Tensor): @@ -1005,27 +1004,48 @@ def gated_delta_rule_decode_pretranspose( If None, will be allocated automatically. use_qk_l2norm (bool): Whether to apply L2 normalization to q and k. Default: ``True``. + state_indices (Optional[torch.Tensor]): + When provided, enables pool-indexed (zero-copy) mode. Shape ``[B]``, + dtype ``int32``. Each element maps a batch element to a slot in the + state pool. Negative values are treated as padding (output zeroed, + state not updated). Default: ``None`` (direct 1:1 batch mapping). Returns: Tuple[torch.Tensor, torch.Tensor]: - output: Output tensor of shape ``[B, 1, HV, V]`` - - state: Updated state tensor of shape ``[B, HV, V, K]`` + - state: The (updated) state tensor Note: - Requires SM90 (Hopper) architecture - State is updated in-place - K and V must be multiples of 4 for vectorized loads - - State layout is v-major: [B, HV, V, K] + - State layout is V-major: [*, HV, V, K] """ + use_pool_indexing = state_indices is not None + # Validate input shapes B, T, H, K = q.shape assert T == 1, f"Decode only supports T=1, got T={T}" _, _, HV, V = v.shape + pool_size = state.shape[0] # Validate state shape - assert state.shape == (B, HV, V, K), ( - f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}" + assert state.shape[1:] == (HV, V, K), ( + f"Expected state shape [*, HV={HV}, V={V}, K={K}], got {state.shape}" ) + if not use_pool_indexing: + assert state.shape[0] == B, ( + f"Without state_indices, state dim-0 must equal B={B}, got {state.shape[0]}" + ) + + # Validate indices (pooled mode) + if use_pool_indexing: + assert state_indices.shape == (B,), ( + f"Expected state_indices shape [{B}], got {state_indices.shape}" + ) + assert state_indices.dtype == torch.int32, ( + f"state_indices must be int32, got {state_indices.dtype}" + ) # Validate K and V constraints assert K >= 128, f"K must be at least 128, got K={K}" @@ -1054,18 +1074,35 @@ def gated_delta_rule_decode_pretranspose( # Kernel outputs bfloat16, allocate in that dtype first output = torch.zeros((B, T, HV, V), dtype=torch.bfloat16, device=q.device) - # Convert state from [B, HV, V, K] to [B*HV, V, K] for kernel - h0_source = state.reshape(B * HV, V, K) + # Convert state from [pool_size, HV, V, K] to [pool_size*HV, V, K] for kernel + h0_source = state.reshape(pool_size * HV, V, K) # Compile kernel with TVM FFI (cached) - cache_key = (B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm) + cache_key = ( + B, + T, + H, + HV, + K, + V, + pool_size, + q.dtype, + scale, + use_qk_l2norm, + use_pool_indexing, + ) cache = _get_compiled_decode_kernel(*cache_key) # Get or create h0_indices and cu_seqlens (cached per config) - if "h0_indices" not in cache or cache["h0_indices"].device != q.device: - cache["h0_indices"] = torch.zeros(B, dtype=torch.int32, device=q.device) + if use_pool_indexing: + h0_indices = state_indices + else: + if "h0_indices" not in cache or cache["h0_indices"].device != q.device: + cache["h0_indices"] = torch.zeros(B, dtype=torch.int32, device=q.device) + h0_indices = cache["h0_indices"] + + if "cu_seqlens" not in cache or cache["cu_seqlens"].device != q.device: cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) - h0_indices = cache["h0_indices"] cu_seqlens = cache["cu_seqlens"] if "compiled" not in cache: @@ -1116,6 +1153,7 @@ def gated_delta_rule_decode_pretranspose( use_initial_state=True, use_qk_l2norm=use_qk_l2norm, is_varlen=False, + use_pool_indexing=use_pool_indexing, stream=stream, options="--enable-tvm-ffi", ) @@ -1129,9 +1167,10 @@ def gated_delta_rule_decode_pretranspose( h0_source, A_log, a, dt_bias, q, k, v, b, output, h0_indices, cu_seqlens, stream ) - # Copy state back only if state was not contiguous + # Copy state back only if non-pooled mode and state was not contiguous # (if contiguous, reshape returns a view and kernel updated state in-place) - if not state.is_contiguous(): + # In pooled mode, state is always updated in-place via pool indexing. + if not use_pool_indexing and not state.is_contiguous(): state.copy_(h0_source.reshape(B, HV, V, K)) # Convert output to target dtype if needed (kernel outputs bfloat16) @@ -1141,230 +1180,28 @@ def gated_delta_rule_decode_pretranspose( return output, state +# ============================================================================ +# NONTRANSPOSE Version Kernels - K-major layout [pool, HV, K, V] +# ============================================================================ + + @functools.cache -def _get_compiled_decode_kernel_pooled( +def _get_compiled_decode_kernel_nontranspose( B: int, T: int, H: int, HV: int, K: int, V: int, - pool_size: int, dtype: torch.dtype, scale: float, use_qk_l2norm: bool, ): - """Cache compiled kernel for pooled pretranspose decode (different constexpr → separate cache).""" + """Cache compiled kernel for given configuration (nontranspose version).""" + # This will be populated on first call return {} -@flashinfer_api -def gated_delta_rule_decode_pretranspose_pooled( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - initial_state: torch.Tensor, - initial_state_indices: torch.Tensor, - A_log: torch.Tensor, - a: torch.Tensor, - dt_bias: torch.Tensor, - b: torch.Tensor, - scale: Optional[float] = None, - output: Optional[torch.Tensor] = None, - use_qk_l2norm: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Gated Delta Rule Decode kernel with direct pool indexing (zero-copy). - - Like ``gated_delta_rule_decode_pretranspose`` but reads/writes state directly - from/to the state pool using ``initial_state_indices``, eliminating the need - for gather before and scatter (index_copy_) after the kernel call. - - Args: - q (torch.Tensor): - Current query of shape ``[B, 1, H, K]``. Must be float16/bfloat16. - k (torch.Tensor): - Current key of shape ``[B, 1, H, K]``. Must be float16/bfloat16. - v (torch.Tensor): - Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16. - initial_state (torch.Tensor): - State pool of shape ``[pool_size, HV, V, K]`` (K-last / V-major layout). - Must be float32. Will be updated in-place at the indexed positions. - initial_state_indices (torch.Tensor): - Indices mapping each batch element to its pool slot, shape ``[B]``. - Must be int32. Negative indices are treated as padding (skipped). - A_log (torch.Tensor): - Log decay parameter of shape ``[HV]``. Must be float32. - a (torch.Tensor): - Input-dependent decay of shape ``[B, 1, HV]``. Must be float16/bfloat16. - dt_bias (torch.Tensor): - Decay bias of shape ``[HV]``. Must be bfloat16 or float32. - b (torch.Tensor): - Update gate (beta) input of shape ``[B, 1, HV]``. Must be float16/bfloat16. - scale (Optional[float]): - Scale factor for queries. If None, defaults to ``1 / sqrt(K)``. - output (Optional[torch.Tensor]): - Pre-allocated output tensor of shape ``[B, 1, HV, V]``. - If None, will be allocated automatically. - use_qk_l2norm (bool): - Whether to apply L2 normalization to q and k. Default: ``True``. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - output: Output tensor of shape ``[B, 1, HV, V]`` - - initial_state: The same state pool tensor (updated in-place) - - Note: - - Requires SM90 (Hopper) architecture - - State is updated in-place at positions given by initial_state_indices - - K and V must be multiples of 4 for vectorized loads - - State layout is V-major: [pool_size, HV, V, K] - - Unlike ``gated_delta_rule_decode_pretranspose``, no gather/scatter needed - """ - # Validate input shapes - B, T, H, K = q.shape - assert T == 1, f"Decode only supports T=1, got T={T}" - _, _, HV, V = v.shape - pool_size = initial_state.shape[0] - - # Validate state shape - assert initial_state.shape == (pool_size, HV, V, K), ( - f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], got {initial_state.shape}" - ) - - # Validate indices - assert initial_state_indices.shape == (B,), ( - f"Expected initial_state_indices shape [{B}], got {initial_state_indices.shape}" - ) - assert initial_state_indices.dtype == torch.int32, ( - f"initial_state_indices must be int32, got {initial_state_indices.dtype}" - ) - - # Validate K and V constraints - assert K >= 128, f"K must be at least 128, got K={K}" - assert V >= 128, f"V must be at least 128, got V={V}" - assert V % TILE_V == 0, ( - f"V must be divisible by {TILE_V} to prevent out-of-bounds access, got V={V}" - ) - - # Validate dtypes - assert q.dtype in (torch.float16, torch.bfloat16), ( - f"q must be float16/bfloat16, got {q.dtype}" - ) - assert initial_state.dtype == torch.float32, ( - f"initial_state must be float32, got {initial_state.dtype}" - ) - assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" - - # Set default scale - if scale is None: - scale = K**-0.5 - - # Allocate output if not provided - output_provided = output is not None - target_dtype = output.dtype if output_provided else q.dtype - - if output is None: - output = torch.zeros((B, T, HV, V), dtype=torch.bfloat16, device=q.device) - - # Reshape state pool from [pool_size, HV, V, K] to [pool_size*HV, V, K] - h0_source = initial_state.reshape(pool_size * HV, V, K) - - # Compile kernel with TVM FFI (cached, separate cache from non-pooled version) - cache_key = (B, T, H, HV, K, V, pool_size, q.dtype, scale, use_qk_l2norm) - cache = _get_compiled_decode_kernel_pooled(*cache_key) - - # Get or create cu_seqlens (cached per config) - if "cu_seqlens" not in cache or cache["cu_seqlens"].device != q.device: - cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) - cu_seqlens = cache["cu_seqlens"] - - if "compiled" not in cache: - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - # Convert tensors to CuTe format for compilation only - h0_source_tensor = from_dlpack(h0_source, assumed_align=16) - A_log_tensor = from_dlpack(A_log, assumed_align=16) - a_tensor = from_dlpack(a, assumed_align=16) - dt_bias_tensor = from_dlpack(dt_bias, assumed_align=16) - q_tensor = from_dlpack(q, assumed_align=16) - k_tensor = from_dlpack(k, assumed_align=16) - v_tensor = from_dlpack(v, assumed_align=16) - b_tensor = from_dlpack(b, assumed_align=16) - o_tensor = from_dlpack(output, assumed_align=16) - h0_indices_tensor = from_dlpack(initial_state_indices, assumed_align=16) - cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16) - - # Choose kernel based on batch size - if B <= 32: - run_func = run_gdn_decode_kernel_small_batch_pretranspose - else: - run_func = run_gdn_decode_kernel_big_batch_pretranspose - - # Use TVM FFI to reduce runtime overhead - compiled = cute.compile( - run_func, - h0_source_tensor, - A_log_tensor, - a_tensor, - dt_bias_tensor, - q_tensor, - k_tensor, - v_tensor, - b_tensor, - o_tensor, - h0_indices_tensor, - cu_seqlens_tensor, - softplus_beta=1.0, - softplus_threshold=20.0, - scale=scale, - HV=HV, - B=B, - T=T, - H=H, - K=K, - V=V, - use_initial_state=True, - use_qk_l2norm=use_qk_l2norm, - is_varlen=False, - use_pool_indexing=True, - stream=stream, - options="--enable-tvm-ffi", - ) - cache["compiled"] = compiled - else: - compiled = cache["compiled"] - - # Run kernel directly with PyTorch tensors (no from_dlpack needed) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - cache["compiled"]( - h0_source, - A_log, - a, - dt_bias, - q, - k, - v, - b, - output, - initial_state_indices, - cu_seqlens, - stream, - ) - - # State is updated in-place via pool indexing — no copy needed - - # Convert output to target dtype if needed (kernel outputs bfloat16) - if output.dtype != target_dtype: - output = output.to(target_dtype) - - return output, initial_state - - -# ============================================================================ -# NONTRANSPOSE Version Kernels - K-major layout [pool, HV, K, V] -# ============================================================================ - - @cute.kernel def gdn_decode_kernel_small_batch_nontranspose( tiled_copy_load: cute.TiledCopy, From 27998076ce76d69b91088e54cfef55a2de20ad5e Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Sat, 7 Feb 2026 20:40:28 +0800 Subject: [PATCH 4/9] fix(gdn): add contiguity assertion for pooled decode state When using pool indexing (state_indices), a non-contiguous state tensor could silently produce incorrect results because the kernel assumes contiguous memory layout for pointer arithmetic. Add an explicit assertion to catch this early. --- flashinfer/gdn_decode.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 0d5459c83b..86e1569eed 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -1040,6 +1040,10 @@ def gated_delta_rule_decode_pretranspose( # Validate indices (pooled mode) if use_pool_indexing: + assert state.is_contiguous(), ( + "state must be contiguous when using pool indexing (state_indices); " + "a non-contiguous tensor may silently produce incorrect results" + ) assert state_indices.shape == (B,), ( f"Expected state_indices shape [{B}], got {state_indices.shape}" ) From a92052a8e97150a154f6af160827c375334c66a6 Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Wed, 25 Feb 2026 23:29:40 +0800 Subject: [PATCH 5/9] test(gdn): add pooled decode tests for zero-copy pool indexing Add comprehensive tests for gated_delta_rule_decode_pretranspose with pool indexing (state_indices parameter): - Test 1: Pooled decode with negative indices (~20% padding) - Test 2: sglang forward_decode calling pattern (unique indices + PAD_SLOT_ID) - Test 3: Pooled vs non-pooled equivalence with identity mapping - Test 4: All-padding batch (output zeros, pool state unchanged) All tests verify output and state against per-sample reference implementation. AI-assisted. --- tests/gdn/test_decode_pooled.py | 479 ++++++++++++++++++++++++++++++++ 1 file changed, 479 insertions(+) create mode 100644 tests/gdn/test_decode_pooled.py diff --git a/tests/gdn/test_decode_pooled.py b/tests/gdn/test_decode_pooled.py new file mode 100644 index 0000000000..b0c9fd0080 --- /dev/null +++ b/tests/gdn/test_decode_pooled.py @@ -0,0 +1,479 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +import torch +import math +import random + +from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose +from flashinfer.utils import get_compute_capability + +try: + from .reference_delta_rule import decode_delta_rule +except ImportError: + import sys + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent)) + from reference_delta_rule import decode_delta_rule + + +def _skip_if_not_sm90_or_later(): + """Skip test if not Hopper (SM90+) or Blackwell (SM100+) architecture.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + cc = get_compute_capability(torch.device("cuda")) + if cc[0] not in [9, 10, 11, 12]: + pytest.skip(f"GDN decode requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}") + + +def _verify_pooled_decode_against_reference( + batch_size: int, + pool_size: int, + num_heads: int, + head_dim: int, + state_indices: torch.Tensor, + dtype_torch: torch.dtype, + seed: int = 42, +): + """ + Core verification logic: run pooled decode kernel and compare per-sample + against the reference implementation. + + Returns (output, state_pool, initial_state_pool) for further assertions. + """ + device = torch.device("cuda") + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Inputs: q, k, v [B, 1, H, D] + q = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + k = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + v = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + k = torch.nn.functional.normalize(k, p=2.0, dim=-1) + + # GDN params + A_log = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + a = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 + b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 + + # State pool: [pool_size, HV, V, K] (K-last layout) + state_pool = torch.randn( + pool_size, num_heads, head_dim, head_dim, dtype=torch.float32, device=device + ) + initial_state_pool = state_pool.clone() + + # Run kernel + output, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=state_pool, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_indices=state_indices, + use_qk_l2norm=True, + ) + torch.cuda.synchronize() + + # Verify per-sample against reference + valid_mask = state_indices >= 0 + invalid_mask = state_indices < 0 + + # Case A: Padding slots (index < 0) must produce zero output + if invalid_mask.any(): + padded_output = output[invalid_mask] + assert torch.all(padded_output == 0), "Padding slots must produce zero output" + + # Case B: Valid slots — compare against reference per sample + valid_indices_local = torch.where(valid_mask)[0].cpu().numpy() + + for i in valid_indices_local: + pool_idx = state_indices[i].item() + + q_i = q[i].float() # [1, H, D] + k_i = k[i].float() + v_i = v[i].float() + a_i = a[i].float() # [1, H] + b_i = b[i].float() + + # Reference expects [B, H, K, V] (K-major), kernel has [B, H, V, K] (V-major) + init_s_i = ( + initial_state_pool[pool_idx].transpose(-2, -1).contiguous().unsqueeze(0) + ) + + ref_o, ref_s = decode_delta_rule( + q_i, + k_i, + v_i, + init_s_i, + A_log=A_log, + a=a_i, + dt_bias=dt_bias, + b=b_i, + scale_factor=1.0 / math.sqrt(head_dim), + use_l2_norm=True, + ) + + # Verify output + out_i = output[i].float() + torch.testing.assert_close( + out_i.squeeze(0), ref_o.squeeze(0).to(out_i.device), atol=1e-2, rtol=1e-2 + ) + + # Verify state update (kernel: [H, V, K], ref: [1, H, K, V]) + ref_s_transposed = ref_s.squeeze(0).transpose(-2, -1) + current_pool_state = state_pool[pool_idx] + torch.testing.assert_close( + current_pool_state, ref_s_transposed.to(device), atol=1e-2, rtol=1e-2 + ) + + # Case C: Untouched pool slots must remain unchanged + used_indices = state_indices[valid_mask].unique() + touched_mask = torch.zeros(pool_size, dtype=torch.bool, device=device) + if len(used_indices) > 0: + touched_mask[used_indices.long()] = True + + untouched_states_final = state_pool[~touched_mask] + untouched_states_initial = initial_state_pool[~touched_mask] + + if len(untouched_states_final) > 0: + torch.testing.assert_close( + untouched_states_final, + untouched_states_initial, + msg="Untouched pool states should not change", + ) + + return output, state_pool, initial_state_pool + + +# ============================================================================ +# Test 1: Basic pooled decode with negative indices +# ============================================================================ + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) +@pytest.mark.parametrize("pool_size_multiplier", [2]) +def test_decode_pooled_with_negative_indices( + dtype, batch_size, pool_size_multiplier, seed=42 +): + """ + Test pooled decode with state_indices, including negative indices for padding. + 20% of batch elements are randomly masked as padding (index = -1). + Valid indices are randomly scattered across the pool. + """ + _skip_if_not_sm90_or_later() + + device = torch.device("cuda") + dtype_torch = getattr(torch, dtype) + num_heads = 16 + head_dim = 128 + pool_size = batch_size * pool_size_multiplier + + # Create indices with ~20% padding + random.seed(seed) + torch.random.manual_seed(seed) + state_indices = torch.arange(batch_size, dtype=torch.int32, device=device) + mask = torch.rand(batch_size, device=device) < 0.2 + state_indices[mask] = -1 + + # Map valid indices to random slots in pool + valid_mask = state_indices >= 0 + num_valid = valid_mask.sum().item() + if num_valid > 0: + valid_slots = torch.randperm(pool_size, device=device)[:num_valid].to( + torch.int32 + ) + state_indices[valid_mask] = valid_slots + + _verify_pooled_decode_against_reference( + batch_size=batch_size, + pool_size=pool_size, + num_heads=num_heads, + head_dim=head_dim, + state_indices=state_indices, + dtype_torch=dtype_torch, + seed=seed, + ) + + +# ============================================================================ +# Test 2: sglang-style pooled decode pattern +# Simulates exactly how sglang's GDNAttnBackend.forward_decode calls the kernel: +# - Full pool passed as state (pool_size+1 slots, slot 0 is sentinel) +# - cache_indices from scheduler, with PAD_SLOT_ID = -1 for padding +# - .to(torch.int32) cast on cache_indices +# ============================================================================ + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("batch_size", [1, 4, 8, 16, 32]) +@pytest.mark.parametrize("pool_size", [128, 256]) +def test_decode_pooled_sglang_pattern(dtype, batch_size, pool_size, seed=42): + """ + Simulate sglang's forward_decode calling pattern: + - ssm_states is the full pool [pool_size+1, HV, V, K] (extra sentinel slot) + - cache_indices are int64, cast to int32 at call site + - PAD_SLOT_ID = -1 for CUDA graph padding slots + - The kernel should NOT gather/scatter — zero-copy pool access + """ + _skip_if_not_sm90_or_later() + + device = torch.device("cuda") + dtype_torch = getattr(torch, dtype) + num_heads = 16 + head_dim = 128 + + PAD_SLOT_ID = -1 + total_pool_size = pool_size + 1 # sglang adds +1 sentinel slot + + # sglang scheduler produces int64 cache_indices — each request has a unique slot + num_valid = batch_size - max(1, batch_size // 4) + cache_indices_int64 = torch.randperm(pool_size, device=device)[:num_valid].to( + torch.int64 + ) + + # Simulate CUDA graph padding: remaining slots are PAD_SLOT_ID + num_padded = batch_size - num_valid + padding = torch.full((num_padded,), PAD_SLOT_ID, dtype=torch.int64, device=device) + cache_indices_int64 = torch.cat([cache_indices_int64, padding]) + + # sglang casts to int32 at call site + state_indices = cache_indices_int64.to(torch.int32) + + _verify_pooled_decode_against_reference( + batch_size=batch_size, + pool_size=total_pool_size, + num_heads=num_heads, + head_dim=head_dim, + state_indices=state_indices, + dtype_torch=dtype_torch, + seed=seed, + ) + + +# ============================================================================ +# Test 3: Pooled vs non-pooled equivalence +# When state_indices is identity [0, 1, ..., B-1] and pool_size == B, +# pooled decode should produce identical results to non-pooled decode. +# ============================================================================ + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) +def test_decode_pooled_vs_nonpooled_equivalence(dtype, batch_size, seed=42): + """ + When state_indices = [0, 1, ..., B-1] (identity mapping) and pool_size == B, + the pooled kernel should produce identical results to the non-pooled kernel. + This verifies zero-copy mode doesn't introduce numerical differences. + """ + _skip_if_not_sm90_or_later() + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + device = torch.device("cuda") + dtype_torch = getattr(torch, dtype) + num_heads = 16 + head_dim = 128 + + # Identical inputs + q = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + k = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + v = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + k = torch.nn.functional.normalize(k, p=2.0, dim=-1) + + A_log = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + a = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 + b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 + + # State: [B, HV, V, K] — same for both + state_base = torch.randn( + batch_size, num_heads, head_dim, head_dim, dtype=torch.float32, device=device + ) + + # Run non-pooled (state_indices=None) + state_nonpooled = state_base.clone() + output_nonpooled, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=state_nonpooled, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + use_qk_l2norm=True, + state_indices=None, + ) + torch.cuda.synchronize() + + # Run pooled with identity indices + state_pooled = state_base.clone() + identity_indices = torch.arange(batch_size, dtype=torch.int32, device=device) + output_pooled, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=state_pooled, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + use_qk_l2norm=True, + state_indices=identity_indices, + ) + torch.cuda.synchronize() + + # Outputs should be identical (same compiled code path, same data) + torch.testing.assert_close( + output_pooled, + output_nonpooled, + atol=1e-5, + rtol=1e-5, + msg="Pooled decode with identity indices should match non-pooled decode", + ) + + # State should be identical + torch.testing.assert_close( + state_pooled, + state_nonpooled, + atol=1e-5, + rtol=1e-5, + msg="Pooled state update with identity indices should match non-pooled", + ) + + +# ============================================================================ +# Test 4: All-padding batch (all negative indices) +# Output should be all zeros, pool state should be completely unchanged. +# ============================================================================ + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) +def test_decode_pooled_all_padding(dtype, batch_size, seed=42): + """ + When ALL state_indices are negative (entire batch is padding), + output must be all zeros and no pool state should be modified. + This happens in CUDA graph when batch_size < max_bs. + """ + _skip_if_not_sm90_or_later() + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + device = torch.device("cuda") + dtype_torch = getattr(torch, dtype) + num_heads = 16 + head_dim = 128 + pool_size = batch_size * 2 + + q = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + k = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + v = torch.randn( + batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device + ) + + A_log = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + a = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 + b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 + + state_pool = torch.randn( + pool_size, num_heads, head_dim, head_dim, dtype=torch.float32, device=device + ) + initial_state_pool = state_pool.clone() + + # ALL negative indices + state_indices = torch.full((batch_size,), -1, dtype=torch.int32, device=device) + + output, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=state_pool, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + state_indices=state_indices, + use_qk_l2norm=True, + ) + torch.cuda.synchronize() + + # All output must be zero + assert torch.all(output == 0), ( + f"All-padding batch should produce all-zero output, " + f"but got max abs = {output.abs().max().item()}" + ) + + # Entire pool must be unchanged + torch.testing.assert_close( + state_pool, + initial_state_pool, + msg="All-padding batch should not modify any pool state", + ) + + +if __name__ == "__main__": + print("Running pooled decode tests...") + + print("\n=== Test 1: Negative indices ===") + test_decode_pooled_with_negative_indices("bfloat16", 32, 2) + print("PASS") + + print("\n=== Test 2: sglang pattern ===") + test_decode_pooled_sglang_pattern("bfloat16", 16, 128) + print("PASS") + + print("\n=== Test 3: Pooled vs non-pooled equivalence ===") + test_decode_pooled_vs_nonpooled_equivalence("bfloat16", 16) + print("PASS") + + print("\n=== Test 4: All padding ===") + test_decode_pooled_all_padding("bfloat16", 16) + print("PASS") + + print("\n✅ All pooled decode tests passed!") + print("\nTo run full test suite:") + print(" pytest tests/gdn/test_decode_pooled.py -v") From 796aacf9c5216ea406b14a6ca9c791bc805a60ae Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Wed, 4 Mar 2026 15:20:30 +0800 Subject: [PATCH 6/9] fix(gdn): remove pool_size from decode kernel cache key (AI-assisted) pool_size only affects tensor shapes at runtime (h0_source reshape), not compiled kernel code. Removing it from the cache key avoids unnecessary recompilation when pool_size changes. Also translate Chinese comments to English per review feedback. --- flashinfer/gdn_decode.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 72d7d24a0f..05a086690b 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -247,7 +247,7 @@ def gdn_decode_kernel_small_batch_pretranspose( gSrc_batch = h0_source[(state_idx, None, None)] # (V, K) gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (state_idx, None, 0)) - # V 方向分 tiles + # Tile along V dimension gSrc = cute.local_tile( gSrc_batch, (TILE_V, TILE_K), (None, 0) ) # (TILE_V, TILE_K, num_v_tiles) @@ -536,7 +536,7 @@ def gdn_decode_kernel_big_batch_pretranspose( gSrc_batch = h0_source[(state_idx, None, None)] # (V, K) gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (state_idx, None, 0)) - # V 方向分 tiles + # Tile along V dimension gSrc = cute.local_tile( gSrc_batch, (TILE_V, TILE_K), (None, 0) ) # (TILE_V, TILE_K, num_v_tiles) @@ -948,7 +948,6 @@ def _get_compiled_decode_kernel( HV: int, K: int, V: int, - pool_size: int, dtype: torch.dtype, scale: float, use_qk_l2norm: bool, @@ -959,8 +958,9 @@ def _get_compiled_decode_kernel( When ``use_pool_indexing=True``, the kernel reads/writes state from a shared pool using ``state_indices``. Because ``use_pool_indexing`` is a ``cutlass.Constexpr``, the two modes produce different compiled CUDA code and - must have separate cache entries (ensured by including ``pool_size`` and - ``use_pool_indexing`` in the key). + must have separate cache entries (ensured by including + ``use_pool_indexing`` in the key). ``pool_size`` is a runtime parameter + (only affects tensor shapes, not compiled code) and is intentionally excluded. """ # This will be populated on first call return {} @@ -1142,7 +1142,6 @@ def gated_delta_rule_decode_pretranspose( HV, K, V, - pool_size, q.dtype, scale, use_qk_l2norm, From 308ad6b26d16d7d7400a994d6e90eed4566deeca Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Wed, 4 Mar 2026 16:36:09 +0800 Subject: [PATCH 7/9] feat(gdn): use sym_int for runtime pool_size in decode kernel (AI-assisted) Replace from_dlpack(h0_source) with make_fake_compact_tensor using cute.sym_int() for the pool_batch dimension, so a single compiled kernel handles any pool_size at runtime. stride_order=(2,1,0) ensures row-major layout with concrete strides for cp.async alignment. Benchmarks show zero performance regression vs compile-time shape: from_dlpack: 0.0306ms median (bs=32, pool=128) sym_int: 0.0307ms median (bs=32, pool=128) --- flashinfer/gdn_decode.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 05a086690b..77d19a03b0 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -1164,8 +1164,19 @@ def gated_delta_rule_decode_pretranspose( if "compiled" not in cache: stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Convert tensors to CuTe format for compilation only - h0_source_tensor = from_dlpack(h0_source, assumed_align=16) + # Convert tensors to CuTe format for compilation only. + # h0_source uses symbolic first dim so the same compiled kernel + # works for any pool_size (pool_size only affects this dimension). + # stride_order=(2,1,0) = row-major: dim K (stride 1) is innermost, + # giving concrete strides (V*K, K, 1) so the compiler can verify + # 128-bit alignment for cp.async copy atoms. + sym_pool_batch = cute.sym_int() + h0_source_tensor = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, + (sym_pool_batch, V, K), + stride_order=(2, 1, 0), + assumed_align=16, + ) A_log_tensor = from_dlpack(A_log, assumed_align=16) a_tensor = from_dlpack(a, assumed_align=16) dt_bias_tensor = from_dlpack(dt_bias, assumed_align=16) From 7bfc18faa8aa29f4dcc534a9b67df5cb84f97a4a Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Thu, 5 Mar 2026 17:44:22 +0800 Subject: [PATCH 8/9] feat: enable f32 pool+indices support in legacy CuTe DSL kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove 'assert not use_pool' from f32 path — the sym_int approach already handles arbitrary pool_size at runtime with zero overhead. Tests 1/2/4 use f32 state with negative indices (padding support). Test 3 uses bf16 state (routed to bf16 fast path). All 23 pooled decode tests pass. AI-assisted. --- flashinfer/gdn_decode.py | 40 ++++++++++++++++++++------------- tests/gdn/test_decode_pooled.py | 9 +++----- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index a38dc11b3a..acb6b2aa2b 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -1111,13 +1111,16 @@ def gated_delta_rule_decode_pretranspose( return_state = initial_state if use_pool else state return output, return_state - # Legacy path: T=1 only, float32 state (no pool+indices support) - assert not use_pool, ( - "pool+indices (initial_state/initial_state_indices) requires bfloat16 state " - "with T in 1..4 and K=V=128 (the gdn_decode_klast_bf16_state fast path)" - ) + # Legacy path: T=1 only, float32 state assert T == 1, f"Decode only supports T=1, got T={T}" - assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" + + if use_pool: + assert initial_state.dtype == torch.float32, ( + f"initial_state must be float32 for legacy path, got {initial_state.dtype}" + ) + else: + assert state is not None, "Either state or initial_state must be provided" + assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" # Validate K and V constraints assert K >= 128, f"K must be at least 128, got K={K}" @@ -1142,12 +1145,17 @@ def gated_delta_rule_decode_pretranspose( target_dtype = output.dtype if output_provided else q.dtype if output is None: - # Kernel outputs bfloat16, allocate in that dtype first output = torch.zeros((B, T, HV, V), dtype=torch.bfloat16, device=q.device) - # Convert state from [B, HV, V, K] to [B*HV, V, K] for kernel - pool_size = B - h0_source = state.reshape(pool_size * HV, V, K) + # Build h0_source: [pool_size*HV, V, K] for kernel + if use_pool: + pool_size = initial_state.shape[0] + h0_source = initial_state.reshape(pool_size * HV, V, K) + return_state = initial_state + else: + pool_size = B + h0_source = state.reshape(pool_size * HV, V, K) + return_state = state # Compile kernel with TVM FFI (cached) cache_key = ( @@ -1166,7 +1174,11 @@ def gated_delta_rule_decode_pretranspose( if "h0_indices" not in cache or cache["h0_indices"].device != q.device: cache["h0_indices"] = torch.zeros(B, dtype=torch.int32, device=q.device) - h0_indices = cache["h0_indices"] + + if use_pool: + h0_indices = initial_state_indices.to(torch.int32) + else: + h0_indices = cache["h0_indices"] if "cu_seqlens" not in cache or cache["cu_seqlens"].device != q.device: cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) @@ -1246,11 +1258,9 @@ def gated_delta_rule_decode_pretranspose( if output.dtype != target_dtype: output = output.to(target_dtype) - # Copy state back only if state was not contiguous - # (if contiguous, reshape returns a view and kernel updated state in-place) - if not state.is_contiguous(): + if not use_pool and not state.is_contiguous(): state.copy_(h0_source.reshape(B, HV, V, K)) - return output, state + return output, return_state # ============================================================================ diff --git a/tests/gdn/test_decode_pooled.py b/tests/gdn/test_decode_pooled.py index 0a9870771a..6224aee456 100644 --- a/tests/gdn/test_decode_pooled.py +++ b/tests/gdn/test_decode_pooled.py @@ -80,9 +80,9 @@ def _verify_pooled_decode_against_reference( a = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 - # State pool: [pool_size, HV, V, K] (K-last layout, bfloat16 for bf16 fast path) + # State pool: [pool_size, HV, V, K] (K-last layout, float32 for CuTe DSL kernel) state_pool = torch.randn( - pool_size, num_heads, head_dim, head_dim, dtype=torch.bfloat16, device=device + pool_size, num_heads, head_dim, head_dim, dtype=torch.float32, device=device ) initial_state_pool = state_pool.clone() @@ -181,7 +181,6 @@ def _verify_pooled_decode_against_reference( # ============================================================================ -@pytest.mark.skip(reason="bf16 fast path kernel does not support negative indices yet") @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) @pytest.mark.parametrize("pool_size_multiplier", [2]) @@ -237,7 +236,6 @@ def test_decode_pooled_with_negative_indices( # ============================================================================ -@pytest.mark.skip(reason="bf16 fast path kernel does not support negative indices yet") @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 8, 16, 32]) @pytest.mark.parametrize("pool_size", [128, 256]) @@ -387,7 +385,6 @@ def test_decode_pooled_vs_nonpooled_equivalence(dtype, batch_size, seed=42): # ============================================================================ -@pytest.mark.skip(reason="bf16 fast path kernel does not support negative indices yet") @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) def test_decode_pooled_all_padding(dtype, batch_size, seed=42): @@ -424,7 +421,7 @@ def test_decode_pooled_all_padding(dtype, batch_size, seed=42): b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 state_pool = torch.randn( - pool_size, num_heads, head_dim, head_dim, dtype=torch.bfloat16, device=device + pool_size, num_heads, head_dim, head_dim, dtype=torch.float32, device=device ) initial_state_pool = state_pool.clone() From e5df67cdc0826a78c2bc0a1db7c730b3379a0a2e Mon Sep 17 00:00:00 2001 From: Xuting Zhou Date: Fri, 6 Mar 2026 17:53:19 +0800 Subject: [PATCH 9/9] test: consolidate pool decode tests into test_decode_delta_rule.py Merge test_decode_pooled.py into test_decode_delta_rule.py with: - state_dtype parametrize (bf16 + f32) for pool test - negative indices and all-padding tests (f32 state only) - per-sample Python reference to avoid JIT cache contamination - float32 dt_bias matching SGLang production usage - pytestmark skip preserved to match upstream main CI --- tests/gdn/test_decode_delta_rule.py | 314 +++++++++++++++++- tests/gdn/test_decode_pooled.py | 480 ---------------------------- 2 files changed, 311 insertions(+), 483 deletions(-) delete mode 100644 tests/gdn/test_decode_pooled.py diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 7987a6ad01..3f5183abb5 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -116,7 +116,7 @@ def _test_decode_kernel_pretranspose( A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # dt_bias: decay bias [HV] - dt_bias = torch.randn(num_sab_heads, dtype=dtype_torch, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # a: input-dependent decay [B, 1, HV] # Convert alpha to a: alpha = exp(-exp(A_log) * softplus(a + dt_bias)) @@ -284,7 +284,7 @@ def _test_decode_kernel_nontranspose( A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # dt_bias: decay bias [HV] - dt_bias = torch.randn(num_sab_heads, dtype=dtype_torch, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # a: input-dependent decay [B, 1, HV] a = ( @@ -440,7 +440,7 @@ def _test_decode_kernel_pretranspose_pool( v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 - dt_bias = torch.randn(num_sab_heads, dtype=dtype_torch) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) @@ -508,6 +508,7 @@ def _test_decode_kernel_pretranspose_pool( ) +@pytest.mark.parametrize("state_dtype", ["bfloat16", "float32"]) @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) @@ -521,9 +522,292 @@ def test_decode_kernel_pretranspose_pool( head_size: int, batch_size: int, scale: float, + state_dtype: str, seed: int = int(os.environ.get("SEED", "0")), ): _test_decode_kernel_pretranspose_pool( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + scale, + state_dtype=state_dtype, + seed=seed, + ) + + +# ============================================================================ +# Test pretranspose kernel pool + indices with negative indices (padding) +# +# Negative pool indices signal padding slots: the kernel must write zeros to +# output for those batch elements and leave the pool untouched. The gather → +# direct-state reference cannot handle negative indices, so we compare valid +# slots against the Python reference and verify padding semantics directly. +# +# Only float32 state is tested because the bf16 fast-path kernel does not +# support negative indices. +# ============================================================================ + + +def _test_decode_kernel_pretranspose_pool_negative_indices( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + scale: float, + pool_multiplier: int = 2, + padding_fraction: float = 0.2, + seed: int | None = None, +): + """Pool+indices with negative indices must zero output for padding slots + and match the Python reference for valid slots.""" + _skip_if_not_sm90_or_later() + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + pool_size = batch_size * pool_multiplier + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=dtype_torch) + k = torch.nn.functional.normalize( + torch.randn(batch_size, 1, num_k_heads, head_size, dtype=dtype_torch), + p=2.0, + dim=-1, + ) + v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) + + A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 + b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) + + # Float32 state pool (only f32 CuTe DSL kernel supports negative indices) + pool = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=torch.float32 + ) + + # Build indices with ~padding_fraction padding slots + indices = torch.arange(batch_size, dtype=torch.int32, device=device) + mask = torch.rand(batch_size, device=device) < padding_fraction + # Ensure at least one valid and one padding slot when batch_size >= 2 + if batch_size >= 2: + mask[0] = False # first slot valid + mask[-1] = True # last slot padding + indices[mask] = -1 + + # Map valid indices to random non-contiguous pool slots + valid_mask = indices >= 0 + num_valid = valid_mask.sum().item() + if num_valid > 0: + valid_slots = torch.randperm(pool_size, device=device)[:num_valid].to( + torch.int32 + ) + indices[valid_mask] = valid_slots + + pool_under_test = pool.clone() + out_pool, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + use_qk_l2norm=True, + initial_state=pool_under_test, + initial_state_indices=indices, + ) + torch.cuda.synchronize() + + # ── Padding slots must produce zero output ──────────────────────────────── + invalid_mask = indices < 0 + if invalid_mask.any(): + padded_output = out_pool[invalid_mask] + assert torch.all(padded_output == 0), ( + f"Padding slots must produce zero output, " + f"but got max abs = {padded_output.abs().max().item()}" + ) + + # ── Valid slots: compare per-sample against Python reference ────────────── + valid_indices_local = torch.where(valid_mask)[0].cpu().numpy() + pool_snapshot = pool.clone() # original pool before kernel + + for i in valid_indices_local: + pool_idx = indices[i].item() + ref_o, ref_s = decode_delta_rule( + q[i].squeeze(0).unsqueeze(0).float(), # [1, H, K] + k[i].squeeze(0).unsqueeze(0).float(), + v[i].squeeze(0).unsqueeze(0).float(), + pool_snapshot[pool_idx] + .float() + .transpose(-2, -1) + .contiguous() + .unsqueeze(0), # [1, HV, K, V] + A_log=A_log, + a=a[i].squeeze(0).unsqueeze(0), # [1, HV] + dt_bias=dt_bias.float(), + b=b[i].squeeze(0).unsqueeze(0), + scale_factor=scale, + use_l2_norm=True, + ) + # Output + torch.testing.assert_close( + out_pool[i].float().squeeze(0), + ref_o.squeeze(0).to(device), + atol=1e-2, + rtol=1e-2, + ) + # State update (kernel: [HV, V, K], ref: [1, HV, K, V]) + torch.testing.assert_close( + pool_under_test[pool_idx].float(), + ref_s.squeeze(0).transpose(-2, -1).to(device), + atol=1e-2, + rtol=1e-2, + ) + + # ── Untouched pool slots must remain unchanged ──────────────────────────── + used_pool_indices = indices[valid_mask].unique() + touched = torch.zeros(pool_size, dtype=torch.bool, device=device) + if len(used_pool_indices) > 0: + touched[used_pool_indices.long()] = True + torch.testing.assert_close( + pool_under_test[~touched], pool[~touched], atol=0.0, rtol=0.0 + ) + + print( + f"✓ Pool+indices negative-index test passed " + f"(batch={batch_size}, pool={pool_size}, dtype={dtype})" + ) + + +def _test_decode_kernel_pretranspose_pool_all_padding( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + scale: float, + pool_multiplier: int = 2, + seed: int | None = None, +): + """When ALL indices are negative (entire batch is padding), output must be + all zeros and the pool must remain completely unchanged.""" + _skip_if_not_sm90_or_later() + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + pool_size = batch_size * pool_multiplier + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=dtype_torch) + k = torch.randn(batch_size, 1, num_k_heads, head_size, dtype=dtype_torch) + v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) + + A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 + b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) + + pool = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=torch.float32 + ) + indices = torch.full((batch_size,), -1, dtype=torch.int32, device=device) + + pool_under_test = pool.clone() + out, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + use_qk_l2norm=True, + initial_state=pool_under_test, + initial_state_indices=indices, + ) + torch.cuda.synchronize() + + assert torch.all(out == 0), ( + f"All-padding batch must produce zero output, " + f"but got max abs = {out.abs().max().item()}" + ) + torch.testing.assert_close( + pool_under_test, + pool, + atol=0.0, + rtol=0.0, + msg="All-padding batch must not modify any pool state", + ) + + print( + f"✓ Pool+indices all-padding test passed " + f"(batch={batch_size}, pool={pool_size}, dtype={dtype})" + ) + + +@pytest.mark.parametrize("scale", [1.0]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_decode_kernel_pretranspose_pool_negative_indices( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + scale: float, + seed: int = int(os.environ.get("SEED", "0")), +): + _test_decode_kernel_pretranspose_pool_negative_indices( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + scale, + seed=seed, + ) + + +@pytest.mark.parametrize("scale", [1.0]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_decode_kernel_pretranspose_pool_all_padding( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + scale: float, + seed: int = int(os.environ.get("SEED", "0")), +): + _test_decode_kernel_pretranspose_pool_all_padding( dtype, batch_size, num_q_heads, @@ -1105,6 +1389,30 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( seed=42, ) + print("\n=== Testing Pool+indices with negative indices ===") + _test_decode_kernel_pretranspose_pool_negative_indices( + dtype="bfloat16", + batch_size=8, + num_q_heads=16, + num_k_heads=16, + num_v_heads=32, + head_size=128, + scale=1.0, + seed=42, + ) + + print("\n=== Testing Pool+indices all-padding ===") + _test_decode_kernel_pretranspose_pool_all_padding( + dtype="bfloat16", + batch_size=8, + num_q_heads=16, + num_k_heads=16, + num_v_heads=32, + head_size=128, + scale=1.0, + seed=42, + ) + print("\n=== Testing IMPROVED CuTe-DSL version (T=1,2,3,4) ===") if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: for t in [1, 2, 3, 4]: diff --git a/tests/gdn/test_decode_pooled.py b/tests/gdn/test_decode_pooled.py deleted file mode 100644 index 6224aee456..0000000000 --- a/tests/gdn/test_decode_pooled.py +++ /dev/null @@ -1,480 +0,0 @@ -""" -Copyright (c) 2025 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import pytest -import torch -import math -import random - -from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose -from flashinfer.utils import get_compute_capability - -try: - from .reference_delta_rule import decode_delta_rule -except ImportError: - import sys - from pathlib import Path - - sys.path.insert(0, str(Path(__file__).parent)) - from reference_delta_rule import decode_delta_rule - - -def _skip_if_not_sm90_or_later(): - """Skip test if not Hopper (SM90+) or Blackwell (SM100+) architecture.""" - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - cc = get_compute_capability(torch.device("cuda")) - if cc[0] not in [9, 10, 11, 12]: - pytest.skip(f"GDN decode requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}") - - -def _verify_pooled_decode_against_reference( - batch_size: int, - pool_size: int, - num_heads: int, - head_dim: int, - state_indices: torch.Tensor, - dtype_torch: torch.dtype, - seed: int = 42, -): - """ - Core verification logic: run pooled decode kernel and compare per-sample - against the reference implementation. - - Returns (output, state_pool, initial_state_pool) for further assertions. - """ - device = torch.device("cuda") - - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - # Inputs: q, k, v [B, 1, H, D] - q = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - k = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - v = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - k = torch.nn.functional.normalize(k, p=2.0, dim=-1) - - # GDN params - A_log = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 - dt_bias = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 - a = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 - b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 - - # State pool: [pool_size, HV, V, K] (K-last layout, float32 for CuTe DSL kernel) - state_pool = torch.randn( - pool_size, num_heads, head_dim, head_dim, dtype=torch.float32, device=device - ) - initial_state_pool = state_pool.clone() - - # Run kernel (pool+indices path: state=None, initial_state=pool, initial_state_indices=indices) - output, _ = gated_delta_rule_decode_pretranspose( - q=q, - k=k, - v=v, - state=None, - A_log=A_log, - a=a, - dt_bias=dt_bias, - b=b, - initial_state=state_pool, - initial_state_indices=state_indices, - use_qk_l2norm=True, - ) - torch.cuda.synchronize() - - # Verify per-sample against reference - valid_mask = state_indices >= 0 - invalid_mask = state_indices < 0 - - # Case A: Padding slots (index < 0) must produce zero output - if invalid_mask.any(): - padded_output = output[invalid_mask] - assert torch.all(padded_output == 0), "Padding slots must produce zero output" - - # Case B: Valid slots — compare against reference per sample - valid_indices_local = torch.where(valid_mask)[0].cpu().numpy() - - for i in valid_indices_local: - pool_idx = state_indices[i].item() - - q_i = q[i].float() # [1, H, D] - k_i = k[i].float() - v_i = v[i].float() - a_i = a[i].float() # [1, H] - b_i = b[i].float() - - init_s_i = ( - initial_state_pool[pool_idx] - .float() - .transpose(-2, -1) - .contiguous() - .unsqueeze(0) - ) - - ref_o, ref_s = decode_delta_rule( - q_i, - k_i, - v_i, - init_s_i, - A_log=A_log, - a=a_i, - dt_bias=dt_bias, - b=b_i, - scale_factor=1.0 / math.sqrt(head_dim), - use_l2_norm=True, - ) - - # Verify output - out_i = output[i].float() - torch.testing.assert_close( - out_i.squeeze(0), ref_o.squeeze(0).to(out_i.device), atol=1e-2, rtol=1e-2 - ) - - # Verify state update (kernel: [H, V, K], ref: [1, H, K, V]) - ref_s_transposed = ref_s.squeeze(0).transpose(-2, -1) - current_pool_state = state_pool[pool_idx].float() - torch.testing.assert_close( - current_pool_state, ref_s_transposed.to(device), atol=1e-2, rtol=1e-2 - ) - - # Case C: Untouched pool slots must remain unchanged - used_indices = state_indices[valid_mask].unique() - touched_mask = torch.zeros(pool_size, dtype=torch.bool, device=device) - if len(used_indices) > 0: - touched_mask[used_indices.long()] = True - - untouched_states_final = state_pool[~touched_mask] - untouched_states_initial = initial_state_pool[~touched_mask] - - if len(untouched_states_final) > 0: - torch.testing.assert_close( - untouched_states_final, - untouched_states_initial, - msg="Untouched pool states should not change", - ) - - return output, state_pool, initial_state_pool - - -# ============================================================================ -# Test 1: Basic pooled decode with negative indices -# ============================================================================ - - -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) -@pytest.mark.parametrize("pool_size_multiplier", [2]) -def test_decode_pooled_with_negative_indices( - dtype, batch_size, pool_size_multiplier, seed=42 -): - """ - Test pooled decode with state_indices, including negative indices for padding. - 20% of batch elements are randomly masked as padding (index = -1). - Valid indices are randomly scattered across the pool. - """ - _skip_if_not_sm90_or_later() - - device = torch.device("cuda") - dtype_torch = getattr(torch, dtype) - num_heads = 16 - head_dim = 128 - pool_size = batch_size * pool_size_multiplier - - # Create indices with ~20% padding - random.seed(seed) - torch.random.manual_seed(seed) - state_indices = torch.arange(batch_size, dtype=torch.int32, device=device) - mask = torch.rand(batch_size, device=device) < 0.2 - state_indices[mask] = -1 - - # Map valid indices to random slots in pool - valid_mask = state_indices >= 0 - num_valid = valid_mask.sum().item() - if num_valid > 0: - valid_slots = torch.randperm(pool_size, device=device)[:num_valid].to( - torch.int32 - ) - state_indices[valid_mask] = valid_slots - - _verify_pooled_decode_against_reference( - batch_size=batch_size, - pool_size=pool_size, - num_heads=num_heads, - head_dim=head_dim, - state_indices=state_indices, - dtype_torch=dtype_torch, - seed=seed, - ) - - -# ============================================================================ -# Test 2: sglang-style pooled decode pattern -# Simulates exactly how sglang's GDNAttnBackend.forward_decode calls the kernel: -# - Full pool passed as state (pool_size+1 slots, slot 0 is sentinel) -# - cache_indices from scheduler, with PAD_SLOT_ID = -1 for padding -# - .to(torch.int32) cast on cache_indices -# ============================================================================ - - -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("batch_size", [1, 4, 8, 16, 32]) -@pytest.mark.parametrize("pool_size", [128, 256]) -def test_decode_pooled_sglang_pattern(dtype, batch_size, pool_size, seed=42): - """ - Simulate sglang's forward_decode calling pattern: - - ssm_states is the full pool [pool_size+1, HV, V, K] (extra sentinel slot) - - cache_indices are int64, cast to int32 at call site - - PAD_SLOT_ID = -1 for CUDA graph padding slots - - The kernel should NOT gather/scatter — zero-copy pool access - """ - _skip_if_not_sm90_or_later() - - device = torch.device("cuda") - dtype_torch = getattr(torch, dtype) - num_heads = 16 - head_dim = 128 - - PAD_SLOT_ID = -1 - total_pool_size = pool_size + 1 # sglang adds +1 sentinel slot - - # sglang scheduler produces int64 cache_indices — each request has a unique slot - num_valid = batch_size - max(1, batch_size // 4) - cache_indices_int64 = torch.randperm(pool_size, device=device)[:num_valid].to( - torch.int64 - ) - - # Simulate CUDA graph padding: remaining slots are PAD_SLOT_ID - num_padded = batch_size - num_valid - padding = torch.full((num_padded,), PAD_SLOT_ID, dtype=torch.int64, device=device) - cache_indices_int64 = torch.cat([cache_indices_int64, padding]) - - # sglang casts to int32 at call site - state_indices = cache_indices_int64.to(torch.int32) - - _verify_pooled_decode_against_reference( - batch_size=batch_size, - pool_size=total_pool_size, - num_heads=num_heads, - head_dim=head_dim, - state_indices=state_indices, - dtype_torch=dtype_torch, - seed=seed, - ) - - -# ============================================================================ -# Test 3: Pooled vs non-pooled equivalence -# When state_indices is identity [0, 1, ..., B-1] and pool_size == B, -# pooled decode should produce identical results to non-pooled decode. -# ============================================================================ - - -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) -def test_decode_pooled_vs_nonpooled_equivalence(dtype, batch_size, seed=42): - """ - When state_indices = [0, 1, ..., B-1] (identity mapping) and pool_size == B, - the pooled kernel should produce identical results to the non-pooled kernel. - This verifies zero-copy mode doesn't introduce numerical differences. - """ - _skip_if_not_sm90_or_later() - - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - device = torch.device("cuda") - dtype_torch = getattr(torch, dtype) - num_heads = 16 - head_dim = 128 - - # Identical inputs - q = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - k = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - v = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - k = torch.nn.functional.normalize(k, p=2.0, dim=-1) - - A_log = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 - dt_bias = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 - a = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 - b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 - - state_base = torch.randn( - batch_size, num_heads, head_dim, head_dim, dtype=torch.bfloat16, device=device - ) - - state_nonpooled = state_base.clone() - output_nonpooled, _ = gated_delta_rule_decode_pretranspose( - q=q, - k=k, - v=v, - state=state_nonpooled, - A_log=A_log, - a=a, - dt_bias=dt_bias, - b=b, - use_qk_l2norm=True, - ) - torch.cuda.synchronize() - - state_pooled = state_base.clone() - identity_indices = torch.arange(batch_size, dtype=torch.int32, device=device) - output_pooled, _ = gated_delta_rule_decode_pretranspose( - q=q, - k=k, - v=v, - state=None, - A_log=A_log, - a=a, - dt_bias=dt_bias, - b=b, - use_qk_l2norm=True, - initial_state=state_pooled, - initial_state_indices=identity_indices, - ) - torch.cuda.synchronize() - - # Outputs should be identical (same compiled code path, same data) - torch.testing.assert_close( - output_pooled, - output_nonpooled, - atol=1e-5, - rtol=1e-5, - msg="Pooled decode with identity indices should match non-pooled decode", - ) - - # State should be identical - torch.testing.assert_close( - state_pooled, - state_nonpooled, - atol=1e-5, - rtol=1e-5, - msg="Pooled state update with identity indices should match non-pooled", - ) - - -# ============================================================================ -# Test 4: All-padding batch (all negative indices) -# Output should be all zeros, pool state should be completely unchanged. -# ============================================================================ - - -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) -def test_decode_pooled_all_padding(dtype, batch_size, seed=42): - """ - When ALL state_indices are negative (entire batch is padding), - output must be all zeros and no pool state should be modified. - This happens in CUDA graph when batch_size < max_bs. - """ - _skip_if_not_sm90_or_later() - - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - device = torch.device("cuda") - dtype_torch = getattr(torch, dtype) - num_heads = 16 - head_dim = 128 - pool_size = batch_size * 2 - - q = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - k = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - v = torch.randn( - batch_size, 1, num_heads, head_dim, dtype=dtype_torch, device=device - ) - - A_log = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 - dt_bias = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 - a = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 - b = torch.randn(batch_size, 1, num_heads, dtype=dtype_torch, device=device) * 0.1 - - state_pool = torch.randn( - pool_size, num_heads, head_dim, head_dim, dtype=torch.float32, device=device - ) - initial_state_pool = state_pool.clone() - - state_indices = torch.full((batch_size,), -1, dtype=torch.int32, device=device) - - output, _ = gated_delta_rule_decode_pretranspose( - q=q, - k=k, - v=v, - state=None, - A_log=A_log, - a=a, - dt_bias=dt_bias, - b=b, - initial_state=state_pool, - initial_state_indices=state_indices, - use_qk_l2norm=True, - ) - torch.cuda.synchronize() - - # All output must be zero - assert torch.all(output == 0), ( - f"All-padding batch should produce all-zero output, " - f"but got max abs = {output.abs().max().item()}" - ) - - # Entire pool must be unchanged - torch.testing.assert_close( - state_pool, - initial_state_pool, - msg="All-padding batch should not modify any pool state", - ) - - -if __name__ == "__main__": - print("Running pooled decode tests...") - - print("\n=== Test 1: Negative indices ===") - test_decode_pooled_with_negative_indices("bfloat16", 32, 2) - print("PASS") - - print("\n=== Test 2: sglang pattern ===") - test_decode_pooled_sglang_pattern("bfloat16", 16, 128) - print("PASS") - - print("\n=== Test 3: Pooled vs non-pooled equivalence ===") - test_decode_pooled_vs_nonpooled_equivalence("bfloat16", 16) - print("PASS") - - print("\n=== Test 4: All padding ===") - test_decode_pooled_all_padding("bfloat16", 16) - print("PASS") - - print("\n✅ All pooled decode tests passed!") - print("\nTo run full test suite:") - print(" pytest tests/gdn/test_decode_pooled.py -v")