diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 9257f8c0dc..80a7039faf 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -172,8 +172,9 @@ def gated_delta_rule_decode_pretranspose( - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used (supports both the direct ``state`` path and the pool+indices path). - - pool+indices (``initial_state``/``initial_state_indices``) only supported - via the bf16 fast path; float32 state raises an error. + - pool+indices (``initial_state``/``initial_state_indices``) supported on + both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path + (T=1). The float32 path also supports negative indices for padding. - Legacy path (float32 state, T=1): K and V must be multiples of 4. """ # Validate input shapes @@ -239,13 +240,17 @@ def gated_delta_rule_decode_pretranspose( return_state = initial_state if use_pool else state return output, return_state - # Legacy path: T=1 only, float32 state (no pool+indices support) - assert not use_pool, ( - "pool+indices (initial_state/initial_state_indices) requires bfloat16 state " - "with T in 1..4 and K=V=128 (the gdn_decode_klast_bf16_state fast path)" - ) + # Legacy path: T=1 only, float32 state (supports pool+indices via CuTe DSL kernel) + use_pool_indexing = initial_state_indices is not None assert T == 1, f"Decode only supports T=1, got T={T}" - assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" + + if use_pool: + assert initial_state.dtype == torch.float32, ( + f"initial_state must be float32 for legacy path, got {initial_state.dtype}" + ) + else: + assert state is not None, "Either state or initial_state must be provided" + assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" # Validate K and V constraints assert K >= 128, f"K must be at least 128, got K={K}" @@ -273,8 +278,18 @@ def gated_delta_rule_decode_pretranspose( # Kernel outputs bfloat16, allocate in that dtype first output = torch.zeros((B, T, HV, V), dtype=torch.bfloat16, device=q.device) - # Convert state from [B, HV, V, K] to [B*HV, V, K] for kernel - h0_source = state.reshape(B * HV, V, K) + # Build h0_source: [pool_size*HV, V, K] for kernel + if use_pool: + pool_size = initial_state.shape[0] + assert initial_state.is_contiguous(), ( + "initial_state (pool) must be contiguous for correct kernel pointer arithmetic" + ) + h0_source = initial_state.reshape(pool_size * HV, V, K) + return_state = initial_state + else: + pool_size = B + h0_source = state.reshape(pool_size * HV, V, K) + return_state = state # Execute kernel run_pretranspose_decode( @@ -295,18 +310,21 @@ def gated_delta_rule_decode_pretranspose( V, scale, use_qk_l2norm, + use_pool_indexing=use_pool_indexing, + initial_state_indices=initial_state_indices, ) - # Copy state back only if state was not contiguous + # Copy state back only if not using pool and state was not contiguous # (if contiguous, reshape returns a view and kernel updated state in-place) - if not state.is_contiguous(): + # Pool path: kernel writes directly into initial_state via pool indices + if not use_pool and not state.is_contiguous(): state.copy_(h0_source.reshape(B, HV, V, K)) # Convert output to target dtype if needed (kernel outputs bfloat16) if output.dtype != target_dtype: output = output.to(target_dtype) - return output, state + return output, return_state # ============================================================================ diff --git a/flashinfer/gdn_kernels/gdn_decode_pretranspose.py b/flashinfer/gdn_kernels/gdn_decode_pretranspose.py index b0f1963056..a6d0f1bc6e 100644 --- a/flashinfer/gdn_kernels/gdn_decode_pretranspose.py +++ b/flashinfer/gdn_kernels/gdn_decode_pretranspose.py @@ -21,6 +21,7 @@ """ import functools +from typing import Optional import torch import cutlass @@ -68,6 +69,7 @@ def gdn_decode_kernel_small_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], + use_pool_indexing: cutlass.Constexpr[bool] = False, ): """Each block uses pipeline to load one batch and vectorized writeback""" @@ -129,189 +131,203 @@ def gdn_decode_kernel_small_batch_pretranspose( cute.arch.barrier() - # Get current batch - gSrc_batch = h0_source[(batch_idx, None, None)] # (V, K) - gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (batch_idx, None, 0)) - - # V 方向分 tiles - gSrc = cute.local_tile( - gSrc_batch, (TILE_V, TILE_K), (None, 0) - ) # (TILE_V, TILE_K, num_v_tiles) - - # Partition for load - thr_copy_load = tiled_copy_load.get_slice(tidx) - - # =================================================================== - # Prefetch: All threads participate in cp.async load - # =================================================================== - start_v_tiles = batch_inner * num_v_tiles_per_block - prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block) - for v_tiles in range(start_v_tiles, start_v_tiles + prefetch_count): - stage = (v_tiles - start_v_tiles) % NUM_STAGES - - gSrc_tile = gSrc[(None, None, v_tiles)] - sData_stage = sData[(None, None, stage)] - - thr_gSrc = thr_copy_load.partition_S(gSrc_tile) - thr_sData = thr_copy_load.partition_D(sData_stage) - - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + # Compute state index: use pool indexing if enabled + if cutlass.const_expr(use_pool_indexing): + pool_idx = h0_indices[i_n] + state_idx = pool_idx * HV + i_hv + else: + pool_idx = 0 + state_idx = batch_idx + + if pool_idx >= 0: + # Get current batch + gSrc_batch = h0_source[(state_idx, None, None)] # (V, K) + gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (state_idx, None, 0)) + # Tile along V dimension + gSrc = cute.local_tile( + gSrc_batch, (TILE_V, TILE_K), (None, 0) + ) # (TILE_V, TILE_K, num_v_tiles) + + # Partition for load + thr_copy_load = tiled_copy_load.get_slice(tidx) + + # =================================================================== + # Prefetch: All threads participate in cp.async load + # =================================================================== + start_v_tiles = batch_inner * num_v_tiles_per_block + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block) + for v_tiles in range(start_v_tiles, start_v_tiles + prefetch_count): + stage = (v_tiles - start_v_tiles) % NUM_STAGES + + gSrc_tile = gSrc[(None, None, v_tiles)] + sData_stage = sData[(None, None, stage)] + + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) - # Load q, k into BF16 registers using autovec_copy (contiguous pattern) - q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - cute.autovec_copy(q_tile, r_q_bf16) - cute.autovec_copy(k_tile, r_k_bf16) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - # Convert BF16 to FP32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = cutlass.Float32(r_q_bf16[i]) - r_k[i] = cutlass.Float32(r_k_bf16[i]) + # Load q, k into BF16 registers using autovec_copy (contiguous pattern) + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) - # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV - v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) - cute.autovec_copy(v_tile, r_v_bf16) - for i in cutlass.range_constexpr(vec_size): - sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) + # Convert BF16 to FP32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) - cute.arch.barrier() # Ensure all threads finish writing to sV + # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV + v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) + cute.autovec_copy(v_tile, r_v_bf16) + for i in cutlass.range_constexpr(vec_size): + sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) + + cute.arch.barrier() # Ensure all threads finish writing to sV + + # =================================================================== + # Compute g and beta (scalar values) + # =================================================================== + r_g = 0.0 + r_beta = 0.0 + if lane_id == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + + if beta_x <= softplus_threshold: + # softplus(x) = (1/beta) * log(1 + exp(beta*x)) + # Compute in Float32 + exp_beta_x = cute.exp(beta_x, fastmath=True) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x - # =================================================================== - # Compute g and beta (scalar values) - # =================================================================== - r_g = 0.0 - r_beta = 0.0 - if lane_id == 0: - x = r_a + r_dt_bias - beta_x = softplus_beta * x - softplus_x = 0.0 - - if beta_x <= softplus_threshold: - # softplus(x) = (1/beta) * log(1 + exp(beta*x)) - # Compute in Float32 - exp_beta_x = cute.exp(beta_x, fastmath=True) - log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) - softplus_x = cutlass.Float32( - (cutlass.Float32(1.0) / softplus_beta) * log_result - ) - else: - softplus_x = x + # Compute g = exp(A_log) * softplus_x + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x - # Compute g = exp(A_log) * softplus_x - r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + # Compute beta = 1 / (1 + exp(-b)) + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) - # Compute beta = 1 / (1 + exp(-b)) - r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) + # Store to scalar (Float32) + r_g = cute.exp(r_g_value, fastmath=True) - # Store to scalar (Float32) - r_g = cute.exp(r_g_value, fastmath=True) + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) - r_g = cute.arch.shuffle_sync(r_g, 0) - r_beta = cute.arch.shuffle_sync(r_beta, 0) + if use_qk_l2norm: + # Compute L2 norm of q and k + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + # Warp-level reduction using butterfly shuffle + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) - if use_qk_l2norm: - # Compute L2 norm of q and k - sum_q = 0.0 - sum_k = 0.0 - for i in cutlass.range_constexpr(vec_size): - sum_q += r_q[i] * r_q[i] - sum_k += r_k[i] * r_k[i] - # Warp-level reduction using butterfly shuffle - for offset in [16, 8, 4, 2, 1]: - sum_q += cute.arch.shuffle_sync_bfly( - sum_q, offset=offset, mask=-1, mask_and_clamp=31 - ) - sum_k += cute.arch.shuffle_sync_bfly( - sum_k, offset=offset, mask=-1, mask_and_clamp=31 - ) + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k - inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) - inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + # Apply scaling in Float32 for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * inv_norm_q - r_k[i] = r_k[i] * inv_norm_k + r_q[i] = r_q[i] * scale - # Apply scaling in Float32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * scale + # =================================================================== + # Mainloop: All threads participate + # =================================================================== + end_v_tiles = start_v_tiles + num_v_tiles_per_block + for v_tiles in range(start_v_tiles, end_v_tiles): + stage = (v_tiles - start_v_tiles) % NUM_STAGES - # =================================================================== - # Mainloop: All threads participate - # =================================================================== - end_v_tiles = start_v_tiles + num_v_tiles_per_block - for v_tiles in range(start_v_tiles, end_v_tiles): - stage = (v_tiles - start_v_tiles) % NUM_STAGES + # Step 1: Wait for current stage to complete + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() - # Step 1: Wait for current stage to complete - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + # Step 2: Issue async load for next tile (after compute) + next_v_tiles = v_tiles + prefetch_count + if next_v_tiles < end_v_tiles: + next_stage = (next_v_tiles - start_v_tiles) % NUM_STAGES - # Step 2: Issue async load for next tile (after compute) - next_v_tiles = v_tiles + prefetch_count - if next_v_tiles < end_v_tiles: - next_stage = (next_v_tiles - start_v_tiles) % NUM_STAGES + gSrc_next = gSrc[(None, None, next_v_tiles)] + sData_next = sData[(None, None, next_stage)] - gSrc_next = gSrc[(None, None, next_v_tiles)] - sData_next = sData[(None, None, next_stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) - thr_gSrc = thr_copy_load.partition_S(gSrc_next) - thr_sData = thr_copy_load.partition_D(sData_next) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + # Step 3: Compute using data from current stage (contiguous access pattern) + for row in cutlass.range_constexpr(0, TILE_V, 4): + row_offset = tidx // 32 + sum_hk = 0.0 - # Step 3: Compute using data from current stage (contiguous access pattern) - for row in cutlass.range_constexpr(0, TILE_V, 4): - row_offset = tidx // 32 - sum_hk = 0.0 + # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) + sData_tile = cute.local_tile( + sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) + ) + cute.autovec_copy(sData_tile, r_h) - # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) - sData_tile = cute.local_tile( - sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) - ) - cute.autovec_copy(sData_tile, r_h) + for i in cutlass.range_constexpr(vec_size): + r_h[i] = r_h[i] * r_g + sum_hk += r_h[i] * r_k[i] - for i in cutlass.range_constexpr(vec_size): - r_h[i] = r_h[i] * r_g - sum_hk += r_h[i] * r_k[i] + for offset in [16, 8, 4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, offset=offset, mask=-1, mask_and_clamp=31 + ) - for offset in [16, 8, 4, 2, 1]: - sum_hk += cute.arch.shuffle_sync_bfly( - sum_hk, offset=offset, mask=-1, mask_and_clamp=31 - ) + v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk + v_new = v_new * r_beta - v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk - v_new = v_new * r_beta + sum_hq = 0.0 + for i in cutlass.range_constexpr(vec_size): + r_h[i] += r_k[i] * v_new + sum_hq += r_h[i] * r_q[i] - sum_hq = 0.0 - for i in cutlass.range_constexpr(vec_size): - r_h[i] += r_k[i] * v_new - sum_hq += r_h[i] * r_q[i] + # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) + gDst_tile = cute.local_tile( + gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) + ) + cute.autovec_copy(r_h, gDst_tile) - # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) - gDst_tile = cute.local_tile( - gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) - ) - cute.autovec_copy(r_h, gDst_tile) + for offset in [16, 8, 4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, offset=offset, mask=-1, mask_and_clamp=31 + ) - for offset in [16, 8, 4, 2, 1]: - sum_hq += cute.arch.shuffle_sync_bfly( - sum_hq, offset=offset, mask=-1, mask_and_clamp=31 - ) + o_idx = v_tiles * TILE_V + row + row_offset + if lane_id == 0 and o_idx < V: + sOutput[o_idx] = cutlass.BFloat16(sum_hq) - o_idx = v_tiles * TILE_V + row + row_offset - if lane_id == 0 and o_idx < V: - sOutput[o_idx] = cutlass.BFloat16(sum_hq) + # =================================================================== + # Final writeback: Copy output from shared memory to global memory + # All threads write (V=128, NUM_THREADS=128) + # =================================================================== + cute.arch.barrier() # Ensure all writes to sOutput are complete + if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V: + o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] - # =================================================================== - # Final writeback: Copy output from shared memory to global memory - # All threads write (V=128, NUM_THREADS=128) - # =================================================================== - cute.arch.barrier() # Ensure all writes to sOutput are complete - if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V: - o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] + else: + start_v_tiles = batch_inner * num_v_tiles_per_block + end_v_tiles = start_v_tiles + num_v_tiles_per_block + if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V: + o[(i_n, i_t, i_hv, tidx)] = cutlass.BFloat16(0.0) @cute.kernel @@ -343,6 +359,7 @@ def gdn_decode_kernel_big_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], + use_pool_indexing: cutlass.Constexpr[bool] = False, ): """Each block uses pipeline to load one batch and vectorized writeback""" @@ -400,188 +417,200 @@ def gdn_decode_kernel_big_batch_pretranspose( cute.arch.barrier() - # Get current batch - gSrc_batch = h0_source[(batch_idx, None, None)] # (V, K) - gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (batch_idx, None, 0)) - - # V 方向分 tiles - gSrc = cute.local_tile( - gSrc_batch, (TILE_V, TILE_K), (None, 0) - ) # (TILE_V, TILE_K, num_v_tiles) - - # Partition for load - thr_copy_load = tiled_copy_load.get_slice(tidx) + # Compute state index: use pool indexing if enabled + if cutlass.const_expr(use_pool_indexing): + pool_idx = h0_indices[i_n] + state_idx = pool_idx * HV + i_hv + else: + pool_idx = 0 + state_idx = batch_idx + + if pool_idx >= 0: + # Get current batch + gSrc_batch = h0_source[(state_idx, None, None)] # (V, K) + gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (state_idx, None, 0)) + # Tile along V dimension + gSrc = cute.local_tile( + gSrc_batch, (TILE_V, TILE_K), (None, 0) + ) # (TILE_V, TILE_K, num_v_tiles) + + # Partition for load + thr_copy_load = tiled_copy_load.get_slice(tidx) + + # =================================================================== + # Prefetch: All threads participate in cp.async load + # =================================================================== + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles) + for v_tiles in range(prefetch_count): + stage = v_tiles % NUM_STAGES + + gSrc_tile = gSrc[(None, None, v_tiles)] + sData_stage = sData[(None, None, stage)] + + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) - # =================================================================== - # Prefetch: All threads participate in cp.async load - # =================================================================== - prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles) - for v_tiles in range(prefetch_count): - stage = v_tiles % NUM_STAGES + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - gSrc_tile = gSrc[(None, None, v_tiles)] - sData_stage = sData[(None, None, stage)] + # Load q, k into BF16 registers using autovec_copy (contiguous pattern) + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) - thr_gSrc = thr_copy_load.partition_S(gSrc_tile) - thr_sData = thr_copy_load.partition_D(sData_stage) + # Convert BF16 to FP32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV + v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) + cute.autovec_copy(v_tile, r_v_bf16) + for i in cutlass.range_constexpr(vec_size): + sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) + + cute.arch.barrier() # Ensure all threads finish writing to sV + + # =================================================================== + # Compute g and beta (scalar values) + # =================================================================== + r_g = 0.0 + r_beta = 0.0 + if lane_id == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + + if beta_x <= softplus_threshold: + # softplus(x) = (1/beta) * log(1 + exp(beta*x)) + # Compute in Float32 + exp_beta_x = cute.exp(beta_x, fastmath=True) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x - # Load q, k into BF16 registers using autovec_copy (contiguous pattern) - q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) - cute.autovec_copy(q_tile, r_q_bf16) - cute.autovec_copy(k_tile, r_k_bf16) + # Compute g = exp(A_log) * softplus_x + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x - # Convert BF16 to FP32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = cutlass.Float32(r_q_bf16[i]) - r_k[i] = cutlass.Float32(r_k_bf16[i]) + # Compute beta = 1 / (1 + exp(-b)) + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) - # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV - v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) - cute.autovec_copy(v_tile, r_v_bf16) - for i in cutlass.range_constexpr(vec_size): - sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) + # Store to scalar (Float32) + r_g = cute.exp(r_g_value, fastmath=True) - cute.arch.barrier() # Ensure all threads finish writing to sV + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) - # =================================================================== - # Compute g and beta (scalar values) - # =================================================================== - r_g = 0.0 - r_beta = 0.0 - if lane_id == 0: - x = r_a + r_dt_bias - beta_x = softplus_beta * x - softplus_x = 0.0 - - if beta_x <= softplus_threshold: - # softplus(x) = (1/beta) * log(1 + exp(beta*x)) - # Compute in Float32 - exp_beta_x = cute.exp(beta_x, fastmath=True) - log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) - softplus_x = cutlass.Float32( - (cutlass.Float32(1.0) / softplus_beta) * log_result - ) - else: - softplus_x = x - - # Compute g = exp(A_log) * softplus_x - r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x - - # Compute beta = 1 / (1 + exp(-b)) - r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) - - # Store to scalar (Float32) - r_g = cute.exp(r_g_value, fastmath=True) - - r_g = cute.arch.shuffle_sync(r_g, 0) - r_beta = cute.arch.shuffle_sync(r_beta, 0) + if use_qk_l2norm: + # Compute L2 norm of q and k + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + # Warp-level reduction using butterfly shuffle + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) - if use_qk_l2norm: - # Compute L2 norm of q and k - sum_q = 0.0 - sum_k = 0.0 - for i in cutlass.range_constexpr(vec_size): - sum_q += r_q[i] * r_q[i] - sum_k += r_k[i] * r_k[i] - # Warp-level reduction using butterfly shuffle - for offset in [16, 8, 4, 2, 1]: - sum_q += cute.arch.shuffle_sync_bfly( - sum_q, offset=offset, mask=-1, mask_and_clamp=31 - ) - sum_k += cute.arch.shuffle_sync_bfly( - sum_k, offset=offset, mask=-1, mask_and_clamp=31 - ) + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k - inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) - inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + # Apply scaling in Float32 for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * inv_norm_q - r_k[i] = r_k[i] * inv_norm_k + r_q[i] = r_q[i] * scale - # Apply scaling in Float32 - for i in cutlass.range_constexpr(vec_size): - r_q[i] = r_q[i] * scale + # =================================================================== + # Mainloop: All threads participate + # =================================================================== + for v_tiles in range(num_v_tiles): + stage = v_tiles % NUM_STAGES - # =================================================================== - # Mainloop: All threads participate - # =================================================================== - for v_tiles in range(num_v_tiles): - stage = v_tiles % NUM_STAGES + # Step 1: Wait for current stage to complete + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() - # Step 1: Wait for current stage to complete - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + # Step 2: Issue async load for next tile (after compute) + next_v_tiles = v_tiles + prefetch_count + if next_v_tiles < num_v_tiles: + next_stage = next_v_tiles % NUM_STAGES - # Step 2: Issue async load for next tile (after compute) - next_v_tiles = v_tiles + prefetch_count - if next_v_tiles < num_v_tiles: - next_stage = next_v_tiles % NUM_STAGES + gSrc_next = gSrc[(None, None, next_v_tiles)] + sData_next = sData[(None, None, next_stage)] - gSrc_next = gSrc[(None, None, next_v_tiles)] - sData_next = sData[(None, None, next_stage)] + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) - thr_gSrc = thr_copy_load.partition_S(gSrc_next) - thr_sData = thr_copy_load.partition_D(sData_next) + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + # Step 3: Compute using data from current stage (contiguous access pattern) + for row in cutlass.range_constexpr(0, TILE_V, 4): + row_offset = tidx // 32 + sum_hk = 0.0 - # Step 3: Compute using data from current stage (contiguous access pattern) - for row in cutlass.range_constexpr(0, TILE_V, 4): - row_offset = tidx // 32 - sum_hk = 0.0 + # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) + sData_tile = cute.local_tile( + sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) + ) + cute.autovec_copy(sData_tile, r_h) - # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) - sData_tile = cute.local_tile( - sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) - ) - cute.autovec_copy(sData_tile, r_h) + for i in cutlass.range_constexpr(vec_size): + r_h[i] = r_h[i] * r_g + sum_hk += r_h[i] * r_k[i] - for i in cutlass.range_constexpr(vec_size): - r_h[i] = r_h[i] * r_g - sum_hk += r_h[i] * r_k[i] + for offset in [16, 8, 4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, offset=offset, mask=-1, mask_and_clamp=31 + ) - for offset in [16, 8, 4, 2, 1]: - sum_hk += cute.arch.shuffle_sync_bfly( - sum_hk, offset=offset, mask=-1, mask_and_clamp=31 - ) + v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk + v_new = v_new * r_beta - v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk - v_new = v_new * r_beta + sum_hq = 0.0 + for i in cutlass.range_constexpr(vec_size): + r_h[i] += r_k[i] * v_new + sum_hq += r_h[i] * r_q[i] - sum_hq = 0.0 - for i in cutlass.range_constexpr(vec_size): - r_h[i] += r_k[i] * v_new - sum_hq += r_h[i] * r_q[i] + # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) + gDst_tile = cute.local_tile( + gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) + ) + cute.autovec_copy(r_h, gDst_tile) - # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) - gDst_tile = cute.local_tile( - gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) - ) - cute.autovec_copy(r_h, gDst_tile) + for offset in [16, 8, 4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, offset=offset, mask=-1, mask_and_clamp=31 + ) - for offset in [16, 8, 4, 2, 1]: - sum_hq += cute.arch.shuffle_sync_bfly( - sum_hq, offset=offset, mask=-1, mask_and_clamp=31 - ) + o_idx = v_tiles * TILE_V + row + row_offset + if lane_id == 0 and o_idx < V: + sOutput[o_idx] = cutlass.BFloat16(sum_hq) - o_idx = v_tiles * TILE_V + row + row_offset - if lane_id == 0 and o_idx < V: - sOutput[o_idx] = cutlass.BFloat16(sum_hq) + # =================================================================== + # Final writeback: Copy output from shared memory to global memory + # All threads write (V=128, NUM_THREADS=128) + # =================================================================== + cute.arch.barrier() # Ensure all writes to sOutput are complete - # =================================================================== - # Final writeback: Copy output from shared memory to global memory - # All threads write (V=128, NUM_THREADS=128) - # =================================================================== - cute.arch.barrier() # Ensure all writes to sOutput are complete + if tidx < V: + o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] - if tidx < V: - o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] + else: + if tidx < V: + o[(i_n, i_t, i_hv, tidx)] = cutlass.BFloat16(0.0) @cute.jit @@ -609,15 +638,18 @@ def run_gdn_decode_kernel_small_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], - stream: cuda.CUstream, + use_pool_indexing: cutlass.Constexpr[bool] = False, + stream: cuda.CUstream = None, ): """Launch original pipelined kernel for small batch pretranspose.""" - # h0_source: (B*HV, V, K) + # h0_source: (B*HV, V, K) or (pool_size*HV, V, K) when use_pool_indexing=True batch_size, v_dim, k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], h0_source.layout.shape[2], ) + # Grid size: use B*HV (actual batch) not h0_source.shape[0] (which may be pool_size*HV) + grid_batch = B * HV # Create cp.async copy with cache-global mode (bypass L1) copy_atom = cute.make_copy_atom( @@ -680,8 +712,9 @@ def run_gdn_decode_kernel_small_batch_pretranspose( use_initial_state, use_qk_l2norm, is_varlen, + use_pool_indexing, ).launch( - grid=(batch_size * NUM_BLOCKS_PER_STATE, 1, 1), + grid=(grid_batch * NUM_BLOCKS_PER_STATE, 1, 1), block=[NUM_THREADS, 1, 1], smem=smem_bytes, stream=stream, @@ -713,14 +746,15 @@ def run_gdn_decode_kernel_big_batch_pretranspose( use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], - stream: cuda.CUstream, + use_pool_indexing: cutlass.Constexpr[bool] = False, + stream: cuda.CUstream = None, ): - # h0_source: (B*HV, V, K) batch_size, v_dim, k_dim = ( h0_source.layout.shape[0], h0_source.layout.shape[1], h0_source.layout.shape[2], ) + grid_batch = B * HV # Create cp.async copy with cache-global mode (bypass L1) copy_atom = cute.make_copy_atom( @@ -783,8 +817,9 @@ def run_gdn_decode_kernel_big_batch_pretranspose( use_initial_state, use_qk_l2norm, is_varlen, + use_pool_indexing, ).launch( - grid=(batch_size, 1, 1), + grid=(grid_batch, 1, 1), block=[NUM_THREADS, 1, 1], smem=smem_bytes, stream=stream, @@ -807,6 +842,7 @@ def _get_compiled_decode_kernel( dtype: torch.dtype, scale: float, use_qk_l2norm: bool, + use_pool_indexing: bool = False, ): """Cache compiled kernel for given configuration (pretranspose version).""" # This will be populated on first call @@ -831,33 +867,53 @@ def run_pretranspose_decode( V: int, scale: float, use_qk_l2norm: bool, + use_pool_indexing: bool = False, + initial_state_indices: Optional[torch.Tensor] = None, ): """Compile and execute the pretranspose decode kernel. Args: - h0_source: State tensor reshaped to [B*HV, V, K]. + h0_source: State tensor reshaped to [B*HV, V, K], or [pool_size*HV, V, K] + when use_pool_indexing=True. A_log, a, dt_bias, q, k, v, b: Input tensors. output: Pre-allocated output tensor [B, T, HV, V]. B, T, H, HV, K, V: Dimension sizes. scale: Query scale factor. use_qk_l2norm: Whether to apply L2 normalization. + use_pool_indexing: Whether to use pool-based indirect state indexing. + initial_state_indices: Int32 indices into state pool, shape [B]. + Negative values indicate padding (kernel writes zeros). """ # Compile kernel with TVM FFI (cached) - cache_key = (B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm) + cache_key = (B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm, use_pool_indexing) cache = _get_compiled_decode_kernel(*cache_key) # Get or create h0_indices and cu_seqlens (cached per config) if "h0_indices" not in cache or cache["h0_indices"].device != q.device: cache["h0_indices"] = torch.zeros(B, dtype=torch.int32, device=q.device) cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) - h0_indices = cache["h0_indices"] + + if use_pool_indexing and initial_state_indices is not None: + h0_indices = initial_state_indices.to(torch.int32) + else: + h0_indices = cache["h0_indices"] cu_seqlens = cache["cu_seqlens"] if "compiled" not in cache: stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Convert tensors to CuTe format for compilation only - h0_source_tensor = from_dlpack(h0_source, assumed_align=16) + if use_pool_indexing: + # Use symbolic pool dimension so kernel compiles with dynamic pool_size + sym_pool_batch = cute.sym_int() + h0_source_tensor = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, + (sym_pool_batch, V, K), + stride_order=(2, 1, 0), + assumed_align=16, + ) + else: + h0_source_tensor = from_dlpack(h0_source, assumed_align=16) A_log_tensor = from_dlpack(A_log, assumed_align=16) a_tensor = from_dlpack(a, assumed_align=16) dt_bias_tensor = from_dlpack(dt_bias, assumed_align=16) @@ -897,6 +953,7 @@ def run_pretranspose_decode( V=V, use_initial_state=True, use_qk_l2norm=use_qk_l2norm, + use_pool_indexing=use_pool_indexing, is_varlen=False, stream=stream, options="--enable-tvm-ffi", diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index e617a83bd2..0d764a7fc2 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -116,7 +116,7 @@ def _test_decode_kernel_pretranspose( A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # dt_bias: decay bias [HV] - dt_bias = torch.randn(num_sab_heads, dtype=dtype_torch, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # a: input-dependent decay [B, 1, HV] # Convert alpha to a: alpha = exp(-exp(A_log) * softplus(a + dt_bias)) @@ -284,7 +284,7 @@ def _test_decode_kernel_nontranspose( A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # dt_bias: decay bias [HV] - dt_bias = torch.randn(num_sab_heads, dtype=dtype_torch, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # a: input-dependent decay [B, 1, HV] a = ( @@ -440,7 +440,7 @@ def _test_decode_kernel_pretranspose_pool( v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 - dt_bias = torch.randn(num_sab_heads, dtype=dtype_torch) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) @@ -508,6 +508,7 @@ def _test_decode_kernel_pretranspose_pool( ) +@pytest.mark.parametrize("state_dtype", ["bfloat16", "float32"]) @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) @@ -521,9 +522,292 @@ def test_decode_kernel_pretranspose_pool( head_size: int, batch_size: int, scale: float, + state_dtype: str, seed: int = int(os.environ.get("SEED", "0")), ): _test_decode_kernel_pretranspose_pool( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + scale, + state_dtype=state_dtype, + seed=seed, + ) + + +# ============================================================================ +# Test pretranspose kernel pool + indices with negative indices (padding) +# +# Negative pool indices signal padding slots: the kernel must write zeros to +# output for those batch elements and leave the pool untouched. The gather → +# direct-state reference cannot handle negative indices, so we compare valid +# slots against the Python reference and verify padding semantics directly. +# +# Only float32 state is tested because the bf16 fast-path kernel does not +# support negative indices. +# ============================================================================ + + +def _test_decode_kernel_pretranspose_pool_negative_indices( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + scale: float, + pool_multiplier: int = 2, + padding_fraction: float = 0.2, + seed: int | None = None, +): + """Pool+indices with negative indices must zero output for padding slots + and match the Python reference for valid slots.""" + _skip_if_not_sm90_or_later() + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + pool_size = batch_size * pool_multiplier + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=dtype_torch) + k = torch.nn.functional.normalize( + torch.randn(batch_size, 1, num_k_heads, head_size, dtype=dtype_torch), + p=2.0, + dim=-1, + ) + v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) + + A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 + b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) + + # Float32 state pool (only f32 CuTe DSL kernel supports negative indices) + pool = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=torch.float32 + ) + + # Build indices with ~padding_fraction padding slots + indices = torch.arange(batch_size, dtype=torch.int32, device=device) + mask = torch.rand(batch_size, device=device) < padding_fraction + # Ensure at least one valid and one padding slot when batch_size >= 2 + if batch_size >= 2: + mask[0] = False # first slot valid + mask[-1] = True # last slot padding + indices[mask] = -1 + + # Map valid indices to random non-contiguous pool slots + valid_mask = indices >= 0 + num_valid = valid_mask.sum().item() + if num_valid > 0: + valid_slots = torch.randperm(pool_size, device=device)[:num_valid].to( + torch.int32 + ) + indices[valid_mask] = valid_slots + + pool_under_test = pool.clone() + out_pool, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + use_qk_l2norm=True, + initial_state=pool_under_test, + initial_state_indices=indices, + ) + torch.cuda.synchronize() + + # ── Padding slots must produce zero output ──────────────────────────────── + invalid_mask = indices < 0 + if invalid_mask.any(): + padded_output = out_pool[invalid_mask] + assert torch.all(padded_output == 0), ( + f"Padding slots must produce zero output, " + f"but got max abs = {padded_output.abs().max().item()}" + ) + + # ── Valid slots: compare per-sample against Python reference ────────────── + valid_indices_local = torch.where(valid_mask)[0].cpu().numpy() + pool_snapshot = pool.clone() # original pool before kernel + + for i in valid_indices_local: + pool_idx = indices[i].item() + ref_o, ref_s = decode_delta_rule( + q[i].squeeze(0).unsqueeze(0).float(), # [1, H, K] + k[i].squeeze(0).unsqueeze(0).float(), + v[i].squeeze(0).unsqueeze(0).float(), + pool_snapshot[pool_idx] + .float() + .transpose(-2, -1) + .contiguous() + .unsqueeze(0), # [1, HV, K, V] + A_log=A_log, + a=a[i].squeeze(0).unsqueeze(0), # [1, HV] + dt_bias=dt_bias.float(), + b=b[i].squeeze(0).unsqueeze(0), + scale_factor=scale, + use_l2_norm=True, + ) + # Output + torch.testing.assert_close( + out_pool[i].float().squeeze(0), + ref_o.squeeze(0).to(device), + atol=1e-2, + rtol=1e-2, + ) + # State update (kernel: [HV, V, K], ref: [1, HV, K, V]) + torch.testing.assert_close( + pool_under_test[pool_idx].float(), + ref_s.squeeze(0).transpose(-2, -1).to(device), + atol=1e-2, + rtol=1e-2, + ) + + # ── Untouched pool slots must remain unchanged ──────────────────────────── + used_pool_indices = indices[valid_mask].unique() + touched = torch.zeros(pool_size, dtype=torch.bool, device=device) + if len(used_pool_indices) > 0: + touched[used_pool_indices.long()] = True + torch.testing.assert_close( + pool_under_test[~touched], pool[~touched], atol=0.0, rtol=0.0 + ) + + print( + f"✓ Pool+indices negative-index test passed " + f"(batch={batch_size}, pool={pool_size}, dtype={dtype})" + ) + + +def _test_decode_kernel_pretranspose_pool_all_padding( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + scale: float, + pool_multiplier: int = 2, + seed: int | None = None, +): + """When ALL indices are negative (entire batch is padding), output must be + all zeros and the pool must remain completely unchanged.""" + _skip_if_not_sm90_or_later() + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + pool_size = batch_size * pool_multiplier + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=dtype_torch) + k = torch.randn(batch_size, 1, num_k_heads, head_size, dtype=dtype_torch) + v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) + + A_log = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32) * 0.1 + a = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) * 0.1 + b = torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch) + + pool = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=torch.float32 + ) + indices = torch.full((batch_size,), -1, dtype=torch.int32, device=device) + + pool_under_test = pool.clone() + out, _ = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=None, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b, + scale=scale, + use_qk_l2norm=True, + initial_state=pool_under_test, + initial_state_indices=indices, + ) + torch.cuda.synchronize() + + assert torch.all(out == 0), ( + f"All-padding batch must produce zero output, " + f"but got max abs = {out.abs().max().item()}" + ) + torch.testing.assert_close( + pool_under_test, + pool, + atol=0.0, + rtol=0.0, + msg="All-padding batch must not modify any pool state", + ) + + print( + f"✓ Pool+indices all-padding test passed " + f"(batch={batch_size}, pool={pool_size}, dtype={dtype})" + ) + + +@pytest.mark.parametrize("scale", [1.0]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_decode_kernel_pretranspose_pool_negative_indices( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + scale: float, + seed: int = int(os.environ.get("SEED", "0")), +): + _test_decode_kernel_pretranspose_pool_negative_indices( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + scale, + seed=seed, + ) + + +@pytest.mark.parametrize("scale", [1.0]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_decode_kernel_pretranspose_pool_all_padding( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + scale: float, + seed: int = int(os.environ.get("SEED", "0")), +): + _test_decode_kernel_pretranspose_pool_all_padding( dtype, batch_size, num_q_heads, @@ -1161,6 +1445,30 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( seed=42, ) + print("\n=== Testing Pool+indices with negative indices ===") + _test_decode_kernel_pretranspose_pool_negative_indices( + dtype="bfloat16", + batch_size=8, + num_q_heads=16, + num_k_heads=16, + num_v_heads=32, + head_size=128, + scale=1.0, + seed=42, + ) + + print("\n=== Testing Pool+indices all-padding ===") + _test_decode_kernel_pretranspose_pool_all_padding( + dtype="bfloat16", + batch_size=8, + num_q_heads=16, + num_k_heads=16, + num_v_heads=32, + head_size=128, + scale=1.0, + seed=42, + ) + print("\n=== Testing IMPROVED CuTe-DSL version (T=1,2,3,4) ===") if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: for t in [1, 2, 3, 4]: