From a5a2bac103c829b64e67d2acd53dc1841e3760f4 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Wed, 4 Feb 2026 15:29:07 -0800 Subject: [PATCH 01/11] Add Gated Delta Rule CuTe-DSL kernel for decode-phase inference Implements high-performance Gated Delta Rule linear attention kernel supporting fixed sequence lengths T=1, T=2, T=3, T=4 using NVIDIA CuTe-DSL. Key features: - H state layout: K-last [B, HV, V, K] where K is the contiguous (fastest) dimension - Unified kernel architecture: T=2/3/4 share a single compile-time specialized kernel via Constexpr dispatch; T=1 uses separate kernel with persistent K optimization - L2-normalized Q/K with configurable scale - Gated exponential decay via softplus - Delta rule updates: v_delta = beta * (v - pred) - Bank-conflict-free cross-warp reductions - Async H memory loading with aggressive pipelining - BF16 tensors with FP32 compute for numerical stability - GQA (grouped-query attention) support Also includes: - benchmark_gated_delta_rule.py: Simple benchmark script for measuring kernel perf - Updated __init__.py exports Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- flashinfer/cute_dsl/__init__.py | 7 + .../cute_dsl/benchmark_gated_delta_rule.py | 158 ++ flashinfer/cute_dsl/gated_delta_rule.py | 1825 +++++++++++++++++ 3 files changed, 1990 insertions(+) create mode 100644 flashinfer/cute_dsl/benchmark_gated_delta_rule.py create mode 100644 flashinfer/cute_dsl/gated_delta_rule.py diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 31c5120d6c..8adf300ecf 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -35,6 +35,10 @@ add_rmsnorm_fp4quant, AddRMSNormFP4QuantKernel, ) + from .gated_delta_rule import ( + gated_delta_rule, + GatedDeltaRuleKernel, + ) __all__ = [ # Utils (always available) @@ -56,4 +60,7 @@ # Add + RMSNorm + FP4 Quantization "add_rmsnorm_fp4quant", "AddRMSNormFP4QuantKernel", + # Gated Delta Rule + "gated_delta_rule", + "GatedDeltaRuleKernel", ] diff --git a/flashinfer/cute_dsl/benchmark_gated_delta_rule.py b/flashinfer/cute_dsl/benchmark_gated_delta_rule.py new file mode 100644 index 0000000000..63be0ad1ff --- /dev/null +++ b/flashinfer/cute_dsl/benchmark_gated_delta_rule.py @@ -0,0 +1,158 @@ +""" +Benchmark: Gated Delta Rule CuTe-DSL Kernel + +Simple benchmark showing duration across batch sizes and sequence lengths (T=1,2,3,4). +""" + +import math +import statistics +import torch + + +def get_l2_cache_size(): + """Get L2 cache size in bytes for the current GPU.""" + return torch.cuda.get_device_properties(0).L2_cache_size + + +def benchmark(func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True): + """ + Benchmark a kernel with L2 flushing and return median time in microseconds. + + Args: + func: Function to benchmark + num_iterations: Number of timed iterations + n_warmup: Number of warmup iterations + flush_l2: Whether to flush L2 cache before each iteration + use_dummy_matmul: Whether to use dummy matmul for short-lived kernels + """ + l2_size = get_l2_cache_size() + cache_flush = torch.empty(l2_size, dtype=torch.uint8, device="cuda") + + # Dummy matmul for short-lived kernels (fills GPU pipeline so CUDA events record properly) + if use_dummy_matmul: + A = torch.randn(4096, 4096, dtype=torch.float32, device="cuda") + B = torch.randn(4096, 4096, dtype=torch.float32, device="cuda") + _ = A @ B # Warm up cuBLAS + + # Warmup + for _ in range(n_warmup): + if flush_l2: + cache_flush.zero_() + func() + torch.cuda.synchronize() + + # Benchmark + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)] + + for i in range(num_iterations): + if flush_l2: + cache_flush.zero_() + if use_dummy_matmul: + _ = A @ B # Dummy work to ensure events record properly for short kernels + start_events[i].record() + func() + end_events[i].record() + + torch.cuda.synchronize() + times_us = [s.elapsed_time(e) * 1000 for s, e in zip(start_events, end_events)] + return statistics.median(times_us) + + +def create_inputs(B, T, H=16, HV=32, K=128, V=128): + """Create test inputs.""" + return { + "q": torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16), + "k": torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16), + "v": torch.randn(B, T, HV, V, device="cuda", dtype=torch.bfloat16), + "a": torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16) * 0.1, + "b": torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16), + "A_log": torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1, + "dt_bias": torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1, + "state": torch.randn(B, HV, V, K, device="cuda", dtype=torch.bfloat16), + "scale": 1.0 / math.sqrt(K), + } + + +def main(): + from gated_delta_rule import gated_delta_rule + + print("=" * 70) + print("Gated Delta Rule CuTe-DSL Kernel Benchmark") + print("Config: H=16, HV=32, K=128, V=128, bfloat16") + print("=" * 70) + + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + seqlens = [1, 2, 3, 4] + num_iterations = 100 + + # Results storage + results = {T: {} for T in seqlens} + + # Benchmark each configuration + for T in seqlens: + print(f"\nCompiling and benchmarking T={T}...") + for B in batch_sizes: + inputs = create_inputs(B, T) + state = inputs["state"].clone() + + # Warmup / compile + _ = gated_delta_rule( + A_log=inputs["A_log"], + a=inputs["a"], + dt_bias=inputs["dt_bias"], + q=inputs["q"], + k=inputs["k"], + v=inputs["v"], + b=inputs["b"], + initial_state_source=state, + scale=inputs["scale"], + ) + + def run_kernel(): + return gated_delta_rule( + A_log=inputs["A_log"], + a=inputs["a"], + dt_bias=inputs["dt_bias"], + q=inputs["q"], + k=inputs["k"], + v=inputs["v"], + b=inputs["b"], + initial_state_source=state, + scale=inputs["scale"], + ) + + time_us = benchmark(run_kernel, num_iterations=num_iterations, flush_l2=True, use_dummy_matmul=True) + results[T][B] = time_us + print(f" B={B:>3}: {time_us:>7.1f} us") + + # Summary table + print("\n" + "=" * 70) + print("SUMMARY: Duration (us) by Batch Size and Sequence Length") + print("=" * 70) + + # Header + header = f"{'B':>6} |" + for T in seqlens: + header += f" T={T} |" + print(header) + print("-" * 70) + + # Data rows + for B in batch_sizes: + row = f"{B:>6} |" + for T in seqlens: + row += f" {results[T][B]:>7.1f} |" + print(row) + + print("-" * 70) + + # Averages + print("\nAverage duration per T:") + for T in seqlens: + avg = sum(results[T].values()) / len(results[T]) + print(f" T={T}: {avg:.1f} us") + + +if __name__ == "__main__": + main() diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/cute_dsl/gated_delta_rule.py new file mode 100644 index 0000000000..dc00c0a5ec --- /dev/null +++ b/flashinfer/cute_dsl/gated_delta_rule.py @@ -0,0 +1,1825 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Gated Delta Rule Kernel (Unified Modular) using CuTe-DSL +======================================================== + +High-performance CUDA kernel implementing the Gated Delta Rule linear attention +mechanism for decode-phase inference, supporting sequence lengths T=1, T=2, T=3, T=4. + +Key Features: +- Unified kernel architecture: T=2/3/4 share a single compile-time specialized kernel + using Constexpr dispatch, while T=1 uses a separate kernel with persistent K-in-registers +- L2-normalized Q/K with configurable scale +- Gated exponential decay of hidden state H via softplus +- Delta rule updates: v_delta = beta * (v - pred) +- Bank-conflict-free cross-warp reductions +- Async H memory loading with aggressive pipelining +- BF16 tensors with FP32 compute for numerical stability +- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads +""" + +import math +from typing import Optional + +import cutlass +import cutlass.cute as cute +import cuda.bindings.driver as cuda +import torch +from cutlass import utils +from cutlass._mlir.dialects import nvvm +from cutlass.cute.runtime import from_dlpack + +# ============================================================================== +# CONSTANTS +# ============================================================================== +H_SMEM_PADDING = 8 +H_SMEM_STRIDE = 128 + H_SMEM_PADDING + + +# ============================================================================== +# SHARED HELPER FUNCTIONS +# ============================================================================== + + +@cute.jit +def write_h_chunk_to_smem(h_chunk_f32, h_sh_chunk, lane_idx, k_base): + """Write F32 register H chunk to BF16 SMEM.""" + for i in cutlass.range_constexpr(32): + h_sh_chunk[lane_idx, k_base + i] = h_chunk_f32[i].to(cutlass.BFloat16) + + +@cute.jit +def store_h_smem_to_gmem(h_sh_chunk, h_out, tidx, v_row_offset): + """Store H from SMEM to GMEM using 128-bit stores.""" + copy_bits = 128 + copy_elems = copy_bits // cutlass.BFloat16.width + + thr_layout = cute.make_layout((16, 8), stride=(8, 1)) + val_layout = cute.make_layout((1, copy_elems)) + + from cutlass.cute.nvgpu import CopyUniversalOp + + atom_store = cute.make_copy_atom( + CopyUniversalOp(), cutlass.BFloat16, num_bits_per_copy=copy_bits + ) + tiled_copy = cute.make_tiled_copy_tv(atom_store, thr_layout, val_layout) + thr_copy = tiled_copy.get_slice(tidx) + + for row_iter in cutlass.range_constexpr(2): + for col_iter in cutlass.range_constexpr(2): + s_tile = cute.local_tile(h_sh_chunk, (16, 64), (row_iter, col_iter)) + g_tile = cute.local_tile( + h_out, (16, 64), (row_iter + (v_row_offset // 16), col_iter) + ) + tS = thr_copy.partition_S(s_tile) + tD = thr_copy.partition_D(g_tile) + cute.copy(atom_store, tS, tD) + + +@cute.jit +def load_h_chunk_async(h_sh_chunk, h_global, tidx, row_offset): + """Load H chunk from GMEM to SMEM using async copy.""" + copy_bits = 128 + copy_elems = copy_bits // cutlass.BFloat16.width + + thr_layout = cute.make_layout((16, 8), stride=(8, 1)) + val_layout = cute.make_layout((1, copy_elems)) + + atom_async_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), cutlass.BFloat16, num_bits_per_copy=copy_bits + ) + tiled_copy = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + thr_copy = tiled_copy.get_slice(tidx) + + for row_iter in cutlass.range_constexpr(2): + for col_iter in cutlass.range_constexpr(2): + g_tile = cute.local_tile( + h_global, (16, 64), (row_iter + (row_offset // 16), col_iter) + ) + s_tile = cute.local_tile(h_sh_chunk, (16, 64), (row_iter, col_iter)) + tS = thr_copy.partition_S(g_tile) + tD = thr_copy.partition_D(s_tile) + cute.copy(atom_async_copy, tS, tD) + + +@cute.jit +def compute_single_gate( + alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold +): + """Compute gate values (g_exp, beta) for a single token.""" + x = alpha + dt_bias_val + beta_x = softplus_beta * x + softplus_x = cutlass.Float32(0.0) + if beta_x <= softplus_threshold: + softplus_x = (cutlass.Float32(1.0) / softplus_beta) * cute.math.log( + cutlass.Float32(1.0) + cute.math.exp(beta_x) + ) + else: + softplus_x = x + g = -cute.math.exp(A_log_val) * softplus_x + g_exp = cute.math.exp(g) + beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cute.math.exp(-beta_raw) + ) + return g_exp, beta + + +@cute.jit +def normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps): + """L2-normalize Q and K vectors, then store to shared memory.""" + q_reg = cute.make_rmem_tensor((4,), cutlass.Float32) + k_reg = cute.make_rmem_tensor((4,), cutlass.Float32) + + for i in cutlass.range_constexpr(4): + q_reg[i] = q_head[lane_idx + i * 32].to(cutlass.Float32) + k_reg[i] = k_head[lane_idx + i * 32].to(cutlass.Float32) + + q_sum_sq = cutlass.Float32(0.0) + k_sum_sq = cutlass.Float32(0.0) + q_sum_sq2 = cutlass.Float32(0.0) + k_sum_sq2 = cutlass.Float32(0.0) + + for i in cutlass.range_constexpr(0, 4, 2): + q_sum_sq, q_sum_sq2 = cute.arch.fma_packed_f32x2( + src_a=(q_reg[i], q_reg[i + 1]), + src_b=(q_reg[i], q_reg[i + 1]), + src_c=(q_sum_sq, q_sum_sq2), + ) + k_sum_sq, k_sum_sq2 = cute.arch.fma_packed_f32x2( + src_a=(k_reg[i], k_reg[i + 1]), + src_b=(k_reg[i], k_reg[i + 1]), + src_c=(k_sum_sq, k_sum_sq2), + ) + + q_sum_sq = q_sum_sq + q_sum_sq2 + k_sum_sq = k_sum_sq + k_sum_sq2 + + for i in cutlass.range_constexpr(5): + q_sum_sq = q_sum_sq + cute.arch.shuffle_sync_bfly( + q_sum_sq, offset=1 << i, mask=0xFFFFFFFF + ) + k_sum_sq = k_sum_sq + cute.arch.shuffle_sync_bfly( + k_sum_sq, offset=1 << i, mask=0xFFFFFFFF + ) + + q_norm = cutlass.Float32(1.0) / cute.math.sqrt(q_sum_sq + eps) + k_norm = cutlass.Float32(1.0) / cute.math.sqrt(k_sum_sq + eps) + q_scale_factor = q_norm * scale + + for i in cutlass.range_constexpr(4): + q_sh[lane_idx + i * 32] = q_reg[i] * q_scale_factor + k_sh[lane_idx + i * 32] = k_reg[i] * k_norm + + +@cute.jit +def load_v_to_smem(v_head, v_sh, tidx): + """Load V values from GMEM to SMEM.""" + v_sh[tidx] = v_head[tidx].to(cutlass.Float32) + + +@cute.jit +def load_kq_chunk_from_smem(kq_sh, kq_chunk, k_base): + """Load K or Q chunk from SMEM to registers.""" + for i in cutlass.range_constexpr(32): + kq_chunk[i] = kq_sh[k_base + i] + + +@cute.jit +def decay_h_from_smem_and_compute_pred( + h_sh_chunk, h_chunk, kq_chunk, g_exp, lane_idx, k_base +): + """Load H from SMEM, apply decay, and compute pred = sum_k(h * k).""" + pred = cutlass.Float32(0.0) + pred2 = cutlass.Float32(0.0) + + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32), + h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32), + ), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + + for i in cutlass.range_constexpr(0, 32, 2): + pred, pred2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(kq_chunk[i], kq_chunk[i + 1]), + src_c=(pred, pred2), + ) + + pred = pred + pred2 + return pred + + +@cute.jit +def update_h_with_delta(h_chunk, kq_chunk, v_delta): + """Update H with delta: h = h + k * v_delta.""" + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(kq_chunk[i], kq_chunk[i + 1]), + src_b=(v_delta, v_delta), + src_c=(h_chunk[i], h_chunk[i + 1]), + ) + + +@cute.jit +def compute_output(h_chunk, kq_chunk): + """Compute output = sum_k(h * q).""" + out = cutlass.Float32(0.0) + out2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + out, out2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(kq_chunk[i], kq_chunk[i + 1]), + src_c=(out, out2), + ) + out = out + out2 + return out + + +@cute.jit +def decay_h_in_place(h_chunk, g_exp): + """Apply decay to H in place: h = h * g_exp.""" + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + + +@cute.jit +def cross_warp_reduce_single(reduce_sh, slot, warp_idx, lane_idx, value): + """ + Cross-warp reduction for a single value using bank-conflict-free layout. + Layout: [slot, lane_idx, warp_idx] + """ + reduce_sh[slot, lane_idx, warp_idx] = value + cute.arch.sync_threads() + reduced_value = ( + reduce_sh[slot, lane_idx, 0] + + reduce_sh[slot, lane_idx, 1] + + reduce_sh[slot, lane_idx, 2] + + reduce_sh[slot, lane_idx, 3] + ) + return reduced_value + + +@cute.jit +def cross_warp_reduce_two(reduce_sh, slot1, slot2, warp_idx, lane_idx, value1, value2): + """ + Cross-warp reduction for two values simultaneously using bank-conflict-free layout. + Layout: [slot, lane_idx, warp_idx] + """ + reduce_sh[slot1, lane_idx, warp_idx] = value1 + reduce_sh[slot2, lane_idx, warp_idx] = value2 + cute.arch.sync_threads() + reduced1 = ( + reduce_sh[slot1, lane_idx, 0] + + reduce_sh[slot1, lane_idx, 1] + + reduce_sh[slot1, lane_idx, 2] + + reduce_sh[slot1, lane_idx, 3] + ) + reduced2 = ( + reduce_sh[slot2, lane_idx, 0] + + reduce_sh[slot2, lane_idx, 1] + + reduce_sh[slot2, lane_idx, 2] + + reduce_sh[slot2, lane_idx, 3] + ) + return reduced1, reduced2 + + +@cute.jit +def process_first_token( + h_sh_chunk_curr, + h_chunk, + kq_chunk, + k_sh, + q_sh, + v_sh, + reduce_sh, + o_head, + g_exp, + beta, + v_offset, + pred_slot, + warp_idx, + lane_idx, + k_base, +): + """ + Process the first token in a V-chunk (T=0). + - Load K from SMEM + - Decay H from SMEM and compute pred + - Cross-warp reduce pred (uses pred_slot) + - Update H with delta + - Load Q and compute output + Returns: out (partial output, not yet reduced) + """ + # Load K for this token + load_kq_chunk_from_smem(k_sh, kq_chunk, k_base) + + # Decay H from SMEM and compute pred = H * K + pred = decay_h_from_smem_and_compute_pred( + h_sh_chunk_curr, h_chunk, kq_chunk, g_exp, lane_idx, k_base + ) + + # Reduce pred across warps (slot 0 for first token) + pred_final = cross_warp_reduce_single( + reduce_sh, pred_slot, warp_idx, lane_idx, pred + ) + + # Compute delta and update H + v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta + update_h_with_delta(h_chunk, kq_chunk, v_delta) + + # Load Q and compute output + load_kq_chunk_from_smem(q_sh, kq_chunk, k_base) + out = compute_output(h_chunk, kq_chunk) + + return out + + +@cute.jit +def process_middle_token( + h_chunk, + kq_chunk, + k_sh, + q_sh, + v_sh, + reduce_sh, + o_head_prev, + g_exp, + beta, + v_offset, + out_slot_prev, + pred_slot, + out_prev, + warp_idx, + lane_idx, + k_base, +): + """ + Process a middle token (T=1, T=2 for T=4 kernel). + - Decay H in place + - Load K, compute pred + - Joint reduction of (prev_out, this_pred) + - Store prev output + - Update H with delta + - Load Q and compute output + Returns: out (partial output, not yet reduced) + """ + # Decay H in place + decay_h_in_place(h_chunk, g_exp) + + # Load K and compute pred + load_kq_chunk_from_smem(k_sh, kq_chunk, k_base) + pred = compute_output(h_chunk, kq_chunk) + + # Joint reduction: reduce out_prev and pred together + out_prev_final, pred_final = cross_warp_reduce_two( + reduce_sh, out_slot_prev, pred_slot, warp_idx, lane_idx, out_prev, pred + ) + + # Store previous token's output + if warp_idx == 0: + o_head_prev[v_offset + lane_idx] = out_prev_final.to(cutlass.BFloat16) + + # Compute delta and update H + v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta + update_h_with_delta(h_chunk, kq_chunk, v_delta) + + # Load Q and compute output + load_kq_chunk_from_smem(q_sh, kq_chunk, k_base) + out = compute_output(h_chunk, kq_chunk) + + return out + + +@cute.jit +def process_last_token_and_finish( + h_sh_chunk_curr, + h_chunk, + kq_chunk, + k_sh, + q_sh, + v_sh, + reduce_sh, + o_head_prev, + o_head_last, + g_exp, + beta, + v_offset, + out_slot_prev, + pred_slot, + out_slot_last, + out_prev, + warp_idx, + lane_idx, + k_base, +): + """ + Process the last token and finalize the V-chunk. + - Decay H in place + - Load K, compute pred + - Joint reduction of (prev_out, this_pred) + - Store prev output + - Update H with delta + - Compute last output and reduce + - Write H back to SMEM + - Store last output + """ + # Decay H in place + decay_h_in_place(h_chunk, g_exp) + + # Load K and compute pred + load_kq_chunk_from_smem(k_sh, kq_chunk, k_base) + pred = compute_output(h_chunk, kq_chunk) + + # Joint reduction: reduce out_prev and pred together + out_prev_final, pred_final = cross_warp_reduce_two( + reduce_sh, out_slot_prev, pred_slot, warp_idx, lane_idx, out_prev, pred + ) + + # Store previous token's output + if warp_idx == 0: + o_head_prev[v_offset + lane_idx] = out_prev_final.to(cutlass.BFloat16) + + # Compute delta and update H + v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta + update_h_with_delta(h_chunk, kq_chunk, v_delta) + + # Compute last output + load_kq_chunk_from_smem(q_sh, kq_chunk, k_base) + out_last = compute_output(h_chunk, kq_chunk) + + # Final reduction and store + out_last_final = cross_warp_reduce_single( + reduce_sh, out_slot_last, warp_idx, lane_idx, out_last + ) + write_h_chunk_to_smem(h_chunk, h_sh_chunk_curr, lane_idx, k_base) + if warp_idx == 0: + o_head_last[v_offset + lane_idx] = out_last_final.to(cutlass.BFloat16) + + +# ============================================================================== +# UNIFIED V-CHUNK PROCESSING FOR SEQLEN=2/3/4 +# ============================================================================== + + +@cute.jit +def process_vchunk_unified_234( + h_sh_chunk_curr, + h_sh_chunk_prev, + h_out, + h_chunk, + kq_chunk, + k_sh0, + k_sh1, + k_sh2, + k_sh3, + q_sh0, + q_sh1, + q_sh2, + q_sh3, + v_sh0, + v_sh1, + v_sh2, + v_sh3, + reduce_sh, + o_head0, + o_head1, + o_head2, + o_head3, + g_exp0, + g_exp1, + g_exp2, + g_exp3, + beta0, + beta1, + beta2, + beta3, + v_offset, + prev_v_offset, + store_prev, + tidx, + warp_idx, + lane_idx, + k_base, + NUM_TOKENS: cutlass.Constexpr[int], +): + """ + Unified V-chunk processing for 2, 3, or 4 tokens using Constexpr parameter. + + This function handles V-chunk processing for all multi-token cases (T=2, T=3, T=4) + using compile-time specialization via NUM_TOKENS. + + Pattern: + - Token 0: First token (always) + - Tokens 1 to NUM_TOKENS-2: Middle tokens (compile-time unrolled) + - Token NUM_TOKENS-1: Last token (always) + """ + # Store previous H chunk if needed + if store_prev: + store_h_smem_to_gmem(h_sh_chunk_prev, h_out, tidx, prev_v_offset) + + # Token 0: First token processing (always executed) + out0 = process_first_token( + h_sh_chunk_curr, + h_chunk, + kq_chunk, + k_sh0, + q_sh0, + v_sh0, + reduce_sh, + o_head0, + g_exp0, + beta0, + v_offset, + 0, # pred_slot=0 + warp_idx, + lane_idx, + k_base, + ) + + # Compile-time dispatch based on NUM_TOKENS + if NUM_TOKENS == 2: + # For T=2: Token 1 is the last token + process_last_token_and_finish( + h_sh_chunk_curr, + h_chunk, + kq_chunk, + k_sh1, + q_sh1, + v_sh1, + reduce_sh, + o_head0, + o_head1, + g_exp1, + beta1, + v_offset, + 1, + 2, + 3, # out_slot_prev=1, pred_slot=2, out_slot_last=3 + out0, + warp_idx, + lane_idx, + k_base, + ) + elif NUM_TOKENS == 3: + # For T=3: Token 1 is middle, Token 2 is last + out1 = process_middle_token( + h_chunk, + kq_chunk, + k_sh1, + q_sh1, + v_sh1, + reduce_sh, + o_head0, + g_exp1, + beta1, + v_offset, + 1, + 2, # out_slot_prev=1, pred_slot=2 + out0, + warp_idx, + lane_idx, + k_base, + ) + process_last_token_and_finish( + h_sh_chunk_curr, + h_chunk, + kq_chunk, + k_sh2, + q_sh2, + v_sh2, + reduce_sh, + o_head1, + o_head2, + g_exp2, + beta2, + v_offset, + 3, + 4, + 5, # out_slot_prev=3, pred_slot=4, out_slot_last=5 + out1, + warp_idx, + lane_idx, + k_base, + ) + else: + # For T=4: Tokens 1,2 are middle, Token 3 is last + out1 = process_middle_token( + h_chunk, + kq_chunk, + k_sh1, + q_sh1, + v_sh1, + reduce_sh, + o_head0, + g_exp1, + beta1, + v_offset, + 1, + 2, # out_slot_prev=1, pred_slot=2 + out0, + warp_idx, + lane_idx, + k_base, + ) + out2 = process_middle_token( + h_chunk, + kq_chunk, + k_sh2, + q_sh2, + v_sh2, + reduce_sh, + o_head1, + g_exp2, + beta2, + v_offset, + 3, + 4, # out_slot_prev=3, pred_slot=4 + out1, + warp_idx, + lane_idx, + k_base, + ) + # Last token for NUM_TOKENS=4: Token 3 + process_last_token_and_finish( + h_sh_chunk_curr, + h_chunk, + kq_chunk, + k_sh3, + q_sh3, + v_sh3, + reduce_sh, + o_head2, + o_head3, + g_exp3, + beta3, + v_offset, + 5, + 6, + 7, # out_slot_prev=5, pred_slot=6, out_slot_last=7 + out2, + warp_idx, + lane_idx, + k_base, + ) + + +# ============================================================================== +# SEQLEN=1 KERNEL (Persistent K Optimization) +# ============================================================================== + + +@cute.kernel +def gated_delta_rule_decode_kernel_seqlen1( + gQ: cute.Tensor, + gK: cute.Tensor, + gV: cute.Tensor, + ga: cute.Tensor, + gb: cute.Tensor, + gA_log: cute.Tensor, + gdt_bias: cute.Tensor, + gH: cute.Tensor, + gO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, +): + """ + Seqlen=1 kernel with persistent K optimization. + OPTIMIZATIONS: + 1. PERSISTENT K IN REGISTERS ONLY: K[k_base:k_base+32] kept for entire kernel + Q is reloaded per chunk (lower register pressure than V3) + 2. AGGRESSIVE PIPELINING: Load chunks 2 ahead, store during next compute + 3. [4,32] CROSS-WARP REDUCTION: Correct lane-preserving reduction + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + HV = cutlass.Int32(gV.shape[2]) + H = cutlass.Int32(gQ.shape[2]) + + batch_idx = bidx // HV + value_head_idx = bidx % HV + query_head_idx = value_head_idx // (HV // H) + + smem = utils.SmemAllocator() + + # Compute gates using shared helper + alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + A_log_val = gA_log[value_head_idx] + dt_bias_val = gdt_bias[value_head_idx] + g_exp, beta = compute_single_gate( + alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + ) + + # Allocate SMEM + h_sh_chunk0 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + h_sh_chunk1 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + h_sh_chunk2 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + h_sh_chunk3 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + + q_sh = smem.allocate_tensor(cutlass.Float32, 128) + k_sh = smem.allocate_tensor(cutlass.Float32, 128) + + pred_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) + out_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) + + h_global = gH[(batch_idx, value_head_idx, None, None)] + + # Launch first 2 async loads + load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0) + nvvm.cp_async_commit_group() + load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32) + nvvm.cp_async_commit_group() + + # L2 normalization + q_head = gQ[(batch_idx, 0, query_head_idx, None)] + k_head = gK[(batch_idx, 0, query_head_idx, None)] + + warp_idx = tidx // 32 + lane_idx = tidx % 32 + + # Use shared helper for Q/K normalization (only warp 0 does the work) + if warp_idx == 0: + normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps) + + cute.arch.sync_threads() + + # Load V + v_head = gV[(batch_idx, 0, value_head_idx, None)] + v_sh = smem.allocate_tensor(cutlass.Float32, 128) + v_sh[tidx] = v_head[tidx].to(cutlass.Float32) + + # Registers: h_chunk + k_chunk (persistent) + qk_temp (reused for Q) + h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) + k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) # PERSISTENT K! + qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32) + + k_base = warp_idx * 32 + + # Load K ONCE - keep for entire kernel + for i in cutlass.range_constexpr(32): + k_chunk[i] = k_sh[k_base + i] + + h_out = gH[(batch_idx, value_head_idx, None, None)] + o_head = gO[(batch_idx, 0, value_head_idx, None)] + + # ======================================================================== + # CHUNK 0 + # ======================================================================== + nvvm.cp_async_wait_group(1) + cute.arch.sync_threads() + + pred = cutlass.Float32(0.0) + pred2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + h_sh_chunk0[lane_idx, k_base + i].to(cutlass.Float32), + h_sh_chunk0[lane_idx, k_base + i + 1].to(cutlass.Float32), + ), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + for i in cutlass.range_constexpr(0, 32, 2): + pred, pred2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(k_chunk[i], k_chunk[i + 1]), + src_c=(pred, pred2), + ) + pred = pred + pred2 + + pred_sh[warp_idx, lane_idx] = pred + cute.arch.sync_threads() + pred_final = ( + pred_sh[0, lane_idx] + + pred_sh[1, lane_idx] + + pred_sh[2, lane_idx] + + pred_sh[3, lane_idx] + ) + + v_val = (v_sh[lane_idx] - pred_final) * beta + + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(k_chunk[i], k_chunk[i + 1]), + src_b=(v_val, v_val), + src_c=(h_chunk[i], h_chunk[i + 1]), + ) + + # Load Q for output computation + for i in cutlass.range_constexpr(32): + qk_temp[i] = q_sh[k_base + i] + + out = cutlass.Float32(0.0) + out2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + out, out2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(qk_temp[i], qk_temp[i + 1]), + src_c=(out, out2), + ) + out = out + out2 + + out_sh[warp_idx, lane_idx] = out + cute.arch.sync_threads() + out_final = ( + out_sh[0, lane_idx] + + out_sh[1, lane_idx] + + out_sh[2, lane_idx] + + out_sh[3, lane_idx] + ) + + write_h_chunk_to_smem(h_chunk, h_sh_chunk0, lane_idx, k_base) + if warp_idx == 0: + o_head[lane_idx] = out_final.to(cutlass.BFloat16) + + # ======================================================================== + # CHUNK 1 + # ======================================================================== + nvvm.cp_async_wait_group(0) + cute.arch.sync_threads() + + load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64) + nvvm.cp_async_commit_group() + load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96) + nvvm.cp_async_commit_group() + + store_h_smem_to_gmem(h_sh_chunk0, h_out, tidx, 0) + + pred = cutlass.Float32(0.0) + pred2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + h_sh_chunk1[lane_idx, k_base + i].to(cutlass.Float32), + h_sh_chunk1[lane_idx, k_base + i + 1].to(cutlass.Float32), + ), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + for i in cutlass.range_constexpr(0, 32, 2): + pred, pred2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(k_chunk[i], k_chunk[i + 1]), + src_c=(pred, pred2), + ) + pred = pred + pred2 + + pred_sh[warp_idx, lane_idx] = pred + cute.arch.sync_threads() + pred_final = ( + pred_sh[0, lane_idx] + + pred_sh[1, lane_idx] + + pred_sh[2, lane_idx] + + pred_sh[3, lane_idx] + ) + + v_val = (v_sh[32 + lane_idx] - pred_final) * beta + + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(k_chunk[i], k_chunk[i + 1]), + src_b=(v_val, v_val), + src_c=(h_chunk[i], h_chunk[i + 1]), + ) + + for i in cutlass.range_constexpr(32): + qk_temp[i] = q_sh[k_base + i] + + out = cutlass.Float32(0.0) + out2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + out, out2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(qk_temp[i], qk_temp[i + 1]), + src_c=(out, out2), + ) + out = out + out2 + + out_sh[warp_idx, lane_idx] = out + cute.arch.sync_threads() + out_final = ( + out_sh[0, lane_idx] + + out_sh[1, lane_idx] + + out_sh[2, lane_idx] + + out_sh[3, lane_idx] + ) + + write_h_chunk_to_smem(h_chunk, h_sh_chunk1, lane_idx, k_base) + if warp_idx == 0: + o_head[32 + lane_idx] = out_final.to(cutlass.BFloat16) + + # ======================================================================== + # CHUNK 2 + # ======================================================================== + nvvm.cp_async_wait_group(1) + cute.arch.sync_threads() + + store_h_smem_to_gmem(h_sh_chunk1, h_out, tidx, 32) + + pred = cutlass.Float32(0.0) + pred2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + h_sh_chunk2[lane_idx, k_base + i].to(cutlass.Float32), + h_sh_chunk2[lane_idx, k_base + i + 1].to(cutlass.Float32), + ), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + for i in cutlass.range_constexpr(0, 32, 2): + pred, pred2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(k_chunk[i], k_chunk[i + 1]), + src_c=(pred, pred2), + ) + pred = pred + pred2 + + pred_sh[warp_idx, lane_idx] = pred + cute.arch.sync_threads() + pred_final = ( + pred_sh[0, lane_idx] + + pred_sh[1, lane_idx] + + pred_sh[2, lane_idx] + + pred_sh[3, lane_idx] + ) + + v_val = (v_sh[64 + lane_idx] - pred_final) * beta + + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(k_chunk[i], k_chunk[i + 1]), + src_b=(v_val, v_val), + src_c=(h_chunk[i], h_chunk[i + 1]), + ) + + for i in cutlass.range_constexpr(32): + qk_temp[i] = q_sh[k_base + i] + + out = cutlass.Float32(0.0) + out2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + out, out2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(qk_temp[i], qk_temp[i + 1]), + src_c=(out, out2), + ) + out = out + out2 + + out_sh[warp_idx, lane_idx] = out + cute.arch.sync_threads() + out_final = ( + out_sh[0, lane_idx] + + out_sh[1, lane_idx] + + out_sh[2, lane_idx] + + out_sh[3, lane_idx] + ) + + write_h_chunk_to_smem(h_chunk, h_sh_chunk2, lane_idx, k_base) + if warp_idx == 0: + o_head[64 + lane_idx] = out_final.to(cutlass.BFloat16) + + # ======================================================================== + # CHUNK 3 + # ======================================================================== + nvvm.cp_async_wait_group(0) + cute.arch.sync_threads() + + store_h_smem_to_gmem(h_sh_chunk2, h_out, tidx, 64) + + pred = cutlass.Float32(0.0) + pred2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + h_sh_chunk3[lane_idx, k_base + i].to(cutlass.Float32), + h_sh_chunk3[lane_idx, k_base + i + 1].to(cutlass.Float32), + ), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + for i in cutlass.range_constexpr(0, 32, 2): + pred, pred2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(k_chunk[i], k_chunk[i + 1]), + src_c=(pred, pred2), + ) + pred = pred + pred2 + + pred_sh[warp_idx, lane_idx] = pred + cute.arch.sync_threads() + pred_final = ( + pred_sh[0, lane_idx] + + pred_sh[1, lane_idx] + + pred_sh[2, lane_idx] + + pred_sh[3, lane_idx] + ) + + v_val = (v_sh[96 + lane_idx] - pred_final) * beta + + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(k_chunk[i], k_chunk[i + 1]), + src_b=(v_val, v_val), + src_c=(h_chunk[i], h_chunk[i + 1]), + ) + + for i in cutlass.range_constexpr(32): + qk_temp[i] = q_sh[k_base + i] + + out = cutlass.Float32(0.0) + out2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + out, out2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(qk_temp[i], qk_temp[i + 1]), + src_c=(out, out2), + ) + out = out + out2 + + out_sh[warp_idx, lane_idx] = out + cute.arch.sync_threads() + out_final = ( + out_sh[0, lane_idx] + + out_sh[1, lane_idx] + + out_sh[2, lane_idx] + + out_sh[3, lane_idx] + ) + + write_h_chunk_to_smem(h_chunk, h_sh_chunk3, lane_idx, k_base) + if warp_idx == 0: + o_head[96 + lane_idx] = out_final.to(cutlass.BFloat16) + + cute.arch.sync_threads() + store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96) + + +# ============================================================================== +# UNIFIED SEQLEN=2/3/4 MAIN KERNEL +# ============================================================================== + + +@cute.kernel +def gated_delta_rule_decode_kernel_seqlen234_unified( + gQ: cute.Tensor, # [B, T=2/3/4, H, K=128] + gK: cute.Tensor, # [B, T=2/3/4, H, K=128] + gV: cute.Tensor, # [B, T=2/3/4, HV, V=128] + ga: cute.Tensor, # [B, T=2/3/4, HV] + gb: cute.Tensor, # [B, T=2/3/4, HV] + gA_log: cute.Tensor, # [HV] + gdt_bias: cute.Tensor, # [HV] + gH: cute.Tensor, # [B, HV, V=128, K=128] - K-fast layout + gO: cute.Tensor, # [B, T=2/3/4, HV, V=128] + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + NUM_TOKENS: cutlass.Constexpr[int], # 2, 3, or 4 +): + """ + Unified kernel for Seqlen=2, Seqlen=3 and Seqlen=4 with compile-time specialization. + + Uses cutlass.Constexpr[int] NUM_TOKENS parameter to eliminate dead code paths: + - NUM_TOKENS=2: 4-slot reduce_sh, 2 Q/K/V buffers, 2 gates + - NUM_TOKENS=3: 6-slot reduce_sh, 3 Q/K/V buffers, 3 gates + - NUM_TOKENS=4: 8-slot reduce_sh, 4 Q/K/V buffers, 4 gates + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + HV = cutlass.Int32(gV.shape[2]) + H = cutlass.Int32(gQ.shape[2]) + + batch_idx = bidx // HV + value_head_idx = bidx % HV + query_head_idx = value_head_idx // (HV // H) + + warp_idx = tidx // 32 + lane_idx = tidx % 32 + k_base = warp_idx * 32 + + smem = utils.SmemAllocator() + + # SMEM Allocation - H chunks + h_sh_chunk0 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + h_sh_chunk1 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + h_sh_chunk2 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + h_sh_chunk3 = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + + # Q/K buffers for tokens 0 and 1 (always needed for T>=2) + q_sh0 = smem.allocate_tensor(cutlass.Float32, 128) + k_sh0 = smem.allocate_tensor(cutlass.Float32, 128) + q_sh1 = smem.allocate_tensor(cutlass.Float32, 128) + k_sh1 = smem.allocate_tensor(cutlass.Float32, 128) + + # Q/K buffers for token 2 (only for NUM_TOKENS >= 3) + q_sh2 = smem.allocate_tensor(cutlass.Float32, 128) + k_sh2 = smem.allocate_tensor(cutlass.Float32, 128) + + # Q/K buffers for token 3 (only for NUM_TOKENS=4) + q_sh3 = smem.allocate_tensor(cutlass.Float32, 128) + k_sh3 = smem.allocate_tensor(cutlass.Float32, 128) + + # V buffers + v_sh0 = smem.allocate_tensor(cutlass.Float32, 128) + v_sh1 = smem.allocate_tensor(cutlass.Float32, 128) + v_sh2 = smem.allocate_tensor(cutlass.Float32, 128) + v_sh3 = smem.allocate_tensor(cutlass.Float32, 128) + + # Bank-conflict-free reduce_sh: [slot, lane_idx, warp_idx] + reduce_sh = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128, 4, 1)) + ) + + # Register allocation + h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) + kq_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) + + # Gate computation - always compute gates 0, 1 (for T>=2) + A_log_val = gA_log[value_head_idx] + dt_bias_val = gdt_bias[value_head_idx] + + alpha0 = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + beta_raw0 = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + g_exp0, beta0 = compute_single_gate( + alpha0, beta_raw0, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + ) + + alpha1 = ga[(batch_idx, 1, value_head_idx)].to(cutlass.Float32) + beta_raw1 = gb[(batch_idx, 1, value_head_idx)].to(cutlass.Float32) + g_exp1, beta1 = compute_single_gate( + alpha1, beta_raw1, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + ) + + # Gate 2 - only for NUM_TOKENS >= 3 + g_exp2 = cutlass.Float32(0.0) + beta2 = cutlass.Float32(0.0) + if NUM_TOKENS >= 3: + alpha2 = ga[(batch_idx, 2, value_head_idx)].to(cutlass.Float32) + beta_raw2 = gb[(batch_idx, 2, value_head_idx)].to(cutlass.Float32) + g_exp2, beta2 = compute_single_gate( + alpha2, beta_raw2, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + ) + + # Gate 3 - only for NUM_TOKENS = 4 + g_exp3 = cutlass.Float32(0.0) + beta3 = cutlass.Float32(0.0) + if NUM_TOKENS == 4: + alpha3 = ga[(batch_idx, 3, value_head_idx)].to(cutlass.Float32) + beta_raw3 = gb[(batch_idx, 3, value_head_idx)].to(cutlass.Float32) + g_exp3, beta3 = compute_single_gate( + alpha3, beta_raw3, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + ) + + # Upfront H loading + h_global = gH[(batch_idx, value_head_idx, None, None)] + load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0) + nvvm.cp_async_commit_group() + load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32) + nvvm.cp_async_commit_group() + load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64) + nvvm.cp_async_commit_group() + load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96) + nvvm.cp_async_commit_group() + + # Q/K normalization - tokens 0, 1 always + q_head0 = gQ[(batch_idx, 0, query_head_idx, None)] + k_head0 = gK[(batch_idx, 0, query_head_idx, None)] + q_head1 = gQ[(batch_idx, 1, query_head_idx, None)] + k_head1 = gK[(batch_idx, 1, query_head_idx, None)] + + if warp_idx == 0: + normalize_and_store_qk_to_smem( + q_head0, k_head0, q_sh0, k_sh0, lane_idx, scale, eps + ) + if warp_idx == 1: + normalize_and_store_qk_to_smem( + q_head1, k_head1, q_sh1, k_sh1, lane_idx, scale, eps + ) + + # Token 2 Q/K normalization - only for NUM_TOKENS >= 3 + if NUM_TOKENS >= 3: + q_head2 = gQ[(batch_idx, 2, query_head_idx, None)] + k_head2 = gK[(batch_idx, 2, query_head_idx, None)] + if warp_idx == 2: + normalize_and_store_qk_to_smem( + q_head2, k_head2, q_sh2, k_sh2, lane_idx, scale, eps + ) + + # Token 3 Q/K normalization - only for NUM_TOKENS = 4 + if NUM_TOKENS == 4: + q_head3 = gQ[(batch_idx, 3, query_head_idx, None)] + k_head3 = gK[(batch_idx, 3, query_head_idx, None)] + if warp_idx == 3: + normalize_and_store_qk_to_smem( + q_head3, k_head3, q_sh3, k_sh3, lane_idx, scale, eps + ) + + cute.arch.sync_threads() + + # V loading - tokens 0, 1 always + v_head0 = gV[(batch_idx, 0, value_head_idx, None)] + v_head1 = gV[(batch_idx, 1, value_head_idx, None)] + load_v_to_smem(v_head0, v_sh0, tidx) + load_v_to_smem(v_head1, v_sh1, tidx) + + # Token 2 V loading - only for NUM_TOKENS >= 3 + if NUM_TOKENS >= 3: + v_head2 = gV[(batch_idx, 2, value_head_idx, None)] + load_v_to_smem(v_head2, v_sh2, tidx) + + # Token 3 V loading - only for NUM_TOKENS = 4 + if NUM_TOKENS == 4: + v_head3 = gV[(batch_idx, 3, value_head_idx, None)] + load_v_to_smem(v_head3, v_sh3, tidx) + + # Output pointers - tokens 0, 1 always + h_out = gH[(batch_idx, value_head_idx, None, None)] + o_head0 = gO[(batch_idx, 0, value_head_idx, None)] + o_head1 = gO[(batch_idx, 1, value_head_idx, None)] + + # Token 2 output pointer + o_head2 = o_head1 # Default for T=2 + if NUM_TOKENS >= 3: + o_head2 = gO[(batch_idx, 2, value_head_idx, None)] + + # Token 3 output pointer + o_head3 = o_head2 # Default for T=2,3 + if NUM_TOKENS == 4: + o_head3 = gO[(batch_idx, 3, value_head_idx, None)] + + # Process V-CHUNK 0 + nvvm.cp_async_wait_group(3) + cute.arch.sync_threads() + process_vchunk_unified_234( + h_sh_chunk0, + h_sh_chunk0, + h_out, + h_chunk, + kq_chunk, + k_sh0, + k_sh1, + k_sh2, + k_sh3, + q_sh0, + q_sh1, + q_sh2, + q_sh3, + v_sh0, + v_sh1, + v_sh2, + v_sh3, + reduce_sh, + o_head0, + o_head1, + o_head2, + o_head3, + g_exp0, + g_exp1, + g_exp2, + g_exp3, + beta0, + beta1, + beta2, + beta3, + 0, + 0, + cutlass.Int32(0), + tidx, + warp_idx, + lane_idx, + k_base, + NUM_TOKENS, + ) + + # Process V-CHUNK 1 + nvvm.cp_async_wait_group(2) + cute.arch.sync_threads() + process_vchunk_unified_234( + h_sh_chunk1, + h_sh_chunk0, + h_out, + h_chunk, + kq_chunk, + k_sh0, + k_sh1, + k_sh2, + k_sh3, + q_sh0, + q_sh1, + q_sh2, + q_sh3, + v_sh0, + v_sh1, + v_sh2, + v_sh3, + reduce_sh, + o_head0, + o_head1, + o_head2, + o_head3, + g_exp0, + g_exp1, + g_exp2, + g_exp3, + beta0, + beta1, + beta2, + beta3, + 32, + 0, + cutlass.Int32(1), + tidx, + warp_idx, + lane_idx, + k_base, + NUM_TOKENS, + ) + + # Process V-CHUNK 2 + nvvm.cp_async_wait_group(1) + cute.arch.sync_threads() + process_vchunk_unified_234( + h_sh_chunk2, + h_sh_chunk1, + h_out, + h_chunk, + kq_chunk, + k_sh0, + k_sh1, + k_sh2, + k_sh3, + q_sh0, + q_sh1, + q_sh2, + q_sh3, + v_sh0, + v_sh1, + v_sh2, + v_sh3, + reduce_sh, + o_head0, + o_head1, + o_head2, + o_head3, + g_exp0, + g_exp1, + g_exp2, + g_exp3, + beta0, + beta1, + beta2, + beta3, + 64, + 32, + cutlass.Int32(1), + tidx, + warp_idx, + lane_idx, + k_base, + NUM_TOKENS, + ) + + # Process V-CHUNK 3 + nvvm.cp_async_wait_group(0) + cute.arch.sync_threads() + process_vchunk_unified_234( + h_sh_chunk3, + h_sh_chunk2, + h_out, + h_chunk, + kq_chunk, + k_sh0, + k_sh1, + k_sh2, + k_sh3, + q_sh0, + q_sh1, + q_sh2, + q_sh3, + v_sh0, + v_sh1, + v_sh2, + v_sh3, + reduce_sh, + o_head0, + o_head1, + o_head2, + o_head3, + g_exp0, + g_exp1, + g_exp2, + g_exp3, + beta0, + beta1, + beta2, + beta3, + 96, + 64, + cutlass.Int32(1), + tidx, + warp_idx, + lane_idx, + k_base, + NUM_TOKENS, + ) + + # Final H store + cute.arch.sync_threads() + store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96) + + +# ============================================================================== +# LAUNCH WRAPPERS +# ============================================================================== + + +@cute.jit +def gated_delta_rule_launch_seqlen1( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen1( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + ).launch( + grid=[batch_size * HV, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +@cute.jit +def gated_delta_rule_launch_seqlen2( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen234_unified( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + 2, # NUM_TOKENS=2 + ).launch( + grid=[batch_size * HV, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +@cute.jit +def gated_delta_rule_launch_seqlen3( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen234_unified( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + 3, # NUM_TOKENS=3 + ).launch( + grid=[batch_size * HV, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +@cute.jit +def gated_delta_rule_launch_seqlen4( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen234_unified( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + 4, # NUM_TOKENS=4 + ).launch( + grid=[batch_size * HV, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +# ============================================================================== +# KERNEL CLASS +# ============================================================================== + + +class GatedDeltaRuleKernel: + """ + Gated Delta Rule Kernel for linear attention decode. + + This kernel implements the Gated Delta Rule mechanism supporting sequence + lengths T=1, T=2, T=3, T=4 with optimized CUDA implementations. + + Key features: + - T=1: Persistent K in registers with aggressive pipelining + - T=2/3/4: Unified kernel with compile-time Constexpr specialization + - L2-normalized Q/K with configurable scale + - Gated exponential decay via softplus + - Bank-conflict-free cross-warp reductions + - Async H memory loading + + Args: + seq_len: Sequence length (1, 2, 3, or 4) + """ + + def __init__(self, seq_len: int): + assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}" + self.seq_len = seq_len + self._compiled_kernel = None + + def _get_launch_fn(self): + if self.seq_len == 1: + return gated_delta_rule_launch_seqlen1 + elif self.seq_len == 2: + return gated_delta_rule_launch_seqlen2 + elif self.seq_len == 3: + return gated_delta_rule_launch_seqlen3 + else: + return gated_delta_rule_launch_seqlen4 + + +# ============================================================================== +# PUBLIC API +# ============================================================================== + +_compiled_kernels = {} # Cache: (seqlen, batch_size) -> compiled kernel + + +def gated_delta_rule( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + q: Optional[torch.Tensor] = None, + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + b: Optional[torch.Tensor] = None, + initial_state_source: Optional[torch.Tensor] = None, + initial_state_indices: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = True, + scale: Optional[float] = None, +) -> torch.Tensor: + """ + Gated Delta Rule linear attention kernel. + + Implements the Gated Delta Rule mechanism for decode-phase inference, + supporting sequence lengths T=1, T=2, T=3, T=4. + + Args: + A_log: Log decay parameter [HV] + a: Alpha gate input [B, T, HV] + dt_bias: Delta-t bias [HV] + softplus_beta: Softplus beta parameter (default: 1.0) + softplus_threshold: Softplus threshold (default: 20.0) + q: Query tensor [B, T, H, K] + k: Key tensor [B, T, H, K] + v: Value tensor [B, T, HV, V] + b: Beta gate input [B, T, HV] + initial_state_source: H state [B, HV, V, K] (K-fast layout), modified in-place + initial_state_indices: Not used (for compatibility) + use_qk_l2norm_in_kernel: Whether to L2-normalize Q/K in kernel (default: True) + scale: Optional attention scale (default: 1/sqrt(K)) + + Returns: + output: [B, T, HV, V] + + Example: + >>> B, T, H, K = 16, 1, 16, 128 + >>> HV, V = 32, 128 + >>> q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) + >>> k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) + >>> v = torch.randn(B, T, HV, V, device='cuda', dtype=torch.bfloat16) + >>> a = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) + >>> b = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) + >>> A_log = torch.randn(HV, device='cuda', dtype=torch.float32) + >>> dt_bias = torch.randn(HV, device='cuda', dtype=torch.float32) + >>> h_state = torch.randn(B, HV, V, K, device='cuda', dtype=torch.bfloat16) + >>> output = gated_delta_rule( + ... A_log, a, dt_bias, q=q, k=k, v=v, b=b, + ... initial_state_source=h_state + ... ) + """ + global _compiled_kernels + + B, T, H, K = q.shape + assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}" + HV = v.shape[2] + V = v.shape[3] + + if scale is None: + scale = 1.0 / math.sqrt(K) + + output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) + + q_ = from_dlpack(q, assumed_align=32) + k_ = from_dlpack(k, assumed_align=32) + v_ = from_dlpack(v, assumed_align=32) + a_ = from_dlpack(a, assumed_align=32) + b_ = from_dlpack(b, assumed_align=32) + A_log_ = from_dlpack(A_log, assumed_align=32) + dt_bias_ = from_dlpack(dt_bias, assumed_align=32) + h_ = from_dlpack(initial_state_source, assumed_align=32) + o_ = from_dlpack(output, assumed_align=32) + + scale_f32 = cutlass.Float32(scale) + softplus_beta_f32 = cutlass.Float32(softplus_beta) + softplus_threshold_f32 = cutlass.Float32(softplus_threshold) + eps_f32 = cutlass.Float32(1e-6) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Check cache + cache_key = (T, B) + if cache_key not in _compiled_kernels: + # Select and compile the appropriate kernel + if T == 1: + launch_fn = gated_delta_rule_launch_seqlen1 + elif T == 2: + launch_fn = gated_delta_rule_launch_seqlen2 + elif T == 3: + launch_fn = gated_delta_rule_launch_seqlen3 + else: # T == 4 + launch_fn = gated_delta_rule_launch_seqlen4 + + _compiled_kernels[cache_key] = cute.compile( + launch_fn, + q_, + k_, + v_, + a_, + b_, + A_log_, + dt_bias_, + h_, + o_, + scale_f32, + softplus_beta_f32, + softplus_threshold_f32, + eps_f32, + stream, + options="--generate-line-info", + ) + + # Execute + _compiled_kernels[cache_key]( + q_, + k_, + v_, + a_, + b_, + A_log_, + dt_bias_, + h_, + o_, + scale_f32, + softplus_beta_f32, + softplus_threshold_f32, + eps_f32, + stream, + ) + + return output From 3ac695d5a40bbc05165a7166c6834cb1c8c16672 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:09:01 -0800 Subject: [PATCH 02/11] Fix ruff B905 linter error Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- .../cute_dsl/benchmark_gated_delta_rule.py | 17 +++++++++++++---- flashinfer/cute_dsl/gated_delta_rule.py | 4 +--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/flashinfer/cute_dsl/benchmark_gated_delta_rule.py b/flashinfer/cute_dsl/benchmark_gated_delta_rule.py index 63be0ad1ff..22ec5ae7c5 100644 --- a/flashinfer/cute_dsl/benchmark_gated_delta_rule.py +++ b/flashinfer/cute_dsl/benchmark_gated_delta_rule.py @@ -14,10 +14,12 @@ def get_l2_cache_size(): return torch.cuda.get_device_properties(0).L2_cache_size -def benchmark(func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True): +def benchmark( + func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True +): """ Benchmark a kernel with L2 flushing and return median time in microseconds. - + Args: func: Function to benchmark num_iterations: Number of timed iterations @@ -55,7 +57,9 @@ def benchmark(func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_ma end_events[i].record() torch.cuda.synchronize() - times_us = [s.elapsed_time(e) * 1000 for s, e in zip(start_events, end_events)] + times_us = [ + s.elapsed_time(e) * 1000 for s, e in zip(start_events, end_events, strict=True) + ] return statistics.median(times_us) @@ -122,7 +126,12 @@ def run_kernel(): scale=inputs["scale"], ) - time_us = benchmark(run_kernel, num_iterations=num_iterations, flush_l2=True, use_dummy_matmul=True) + time_us = benchmark( + run_kernel, + num_iterations=num_iterations, + flush_l2=True, + use_dummy_matmul=True, + ) results[T][B] = time_us print(f" B={B:>3}: {time_us:>7.1f} us") diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/cute_dsl/gated_delta_rule.py index dc00c0a5ec..9c86ec9141 100644 --- a/flashinfer/cute_dsl/gated_delta_rule.py +++ b/flashinfer/cute_dsl/gated_delta_rule.py @@ -131,9 +131,7 @@ def compute_single_gate( softplus_x = x g = -cute.math.exp(A_log_val) * softplus_x g_exp = cute.math.exp(g) - beta = cutlass.Float32(1.0) / ( - cutlass.Float32(1.0) + cute.math.exp(-beta_raw) - ) + beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw)) return g_exp, beta From e52e60a1eeb367442efd5572281ad5cc8c2f10c9 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:46:56 -0800 Subject: [PATCH 03/11] Add bf16 h-state reference, fastmath/rsqrt/L1-bypass optimizations to CuTe-DSL kernel, and improved CuTe-DSL benchmarking/testing Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- benchmarks/bench_gdn_decode.py | 375 +++++++++++++++++++++--- flashinfer/cute_dsl/gated_delta_rule.py | 22 +- tests/gdn/reference_delta_rule.py | 86 ++++-- tests/gdn/test_decode_delta_rule.py | 259 +++++++++++++++- 4 files changed, 661 insertions(+), 81 deletions(-) diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 9ec10b8fa5..211748114e 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -18,12 +18,21 @@ GDN Decode Benchmark This benchmark supports: -1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose +1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL 2. Single layout comparison: FlashInfer (CuTe DSL) vs Triton kernel (--compare) 3. MTP benchmark (--version mtp) +4. Improved CuTe-DSL multi-token benchmark (--version improved_cutedsl) for T=1,2,3,4 + +Kernels benchmarked: +- FlashInfer Pretranspose [B, HV, V, K] (V-major layout) +- FlashInfer Nontranspose [B, HV, K, V] (K-major layout) +- Triton Pretranspose [B, HV, V, K] +- Triton Nontranspose [B, HV, K, V] +- Improved CuTe-DSL Kernel [B, HV, V, K] (K-fast layout, supports T=1,2,3,4) + from flashinfer.cute_dsl.gated_delta_rule Usage: - # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL) python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 # Single layout comparison: FlashInfer vs Triton @@ -35,7 +44,10 @@ # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 - # Use Qwen3-Next preset + # Improved CuTe-DSL multi-token benchmark (T=1,2,3,4) + python benchmarks/bench_gdn_decode.py --version improved_cutedsl --batch-size 1 32 128 512 + + # Use Qwen3-Next preset (q=k=16, v=32, d=128) python benchmarks/bench_gdn_decode.py --preset qwen3-next --batch-size 1 32 128 512 """ @@ -50,6 +62,16 @@ ) from flashinfer.testing import bench_gpu_time +# Import the improved CuTe-DSL kernel for benchmarking (supports T=1,2,3,4) +try: + from flashinfer.cute_dsl.gated_delta_rule import ( + gated_delta_rule as improved_cutedsl_gdn, + ) + + IMPROVED_CUTEDSL_AVAILABLE = True +except ImportError: + IMPROVED_CUTEDSL_AVAILABLE = False + # ============================================================================ # Utility Functions @@ -102,6 +124,7 @@ def gdn_decode_bytes( dtype: torch.dtype, seq_len: int = 1, disable_state_update: bool = False, + state_dtype_bytes: int = 4, # 4 for FP32, 2 for BF16 ) -> int: """ Calculate memory bytes for GDN. @@ -110,8 +133,8 @@ def gdn_decode_bytes( Includes: - Q, K, V tensors (input): [B, T, H, K] - dtype - - State tensor (input/output): [B, HV, K, V] - float32 - - Intermediate states (MTP only): [B, T, HV, K, V] - float32 + - State tensor (input/output): [B, HV, K, V] - state_dtype_bytes (FP32=4 or BF16=2) + - Intermediate states (MTP only): [B, T, HV, K, V] - state_dtype_bytes - GDN parameters: A_log (float32), a (dtype), dt_bias (dtype), b (dtype) - Output tensor: [B, T, HV, V] - dtype @@ -129,15 +152,19 @@ def gdn_decode_bytes( # Output tensor: [B, T, HV, V] o_bytes = batch_size * seq_len * num_o_heads * head_size * elem_size - # State tensor (float32): [B, HV, K, V] + # State tensor: [B, HV, K, V] # If disable_state_update=True: only read initial state # If disable_state_update=False: read initial + write final state if disable_state_update: # Read only (e.g., MTP verify mode) - state_bytes = batch_size * num_sab_heads * head_size * head_size * 4 + state_bytes = ( + batch_size * num_sab_heads * head_size * head_size * state_dtype_bytes + ) else: # Read + write (e.g., normal decode) - state_bytes = 2 * batch_size * num_sab_heads * head_size * head_size * 4 + state_bytes = ( + 2 * batch_size * num_sab_heads * head_size * head_size * state_dtype_bytes + ) # GDN parameters # A_log: [HV] - float32 @@ -149,12 +176,17 @@ def gdn_decode_bytes( # b: [B, T, HV] - dtype b_bytes = batch_size * seq_len * num_sab_heads * elem_size - # Intermediate states (float32): [B, T, HV, K, V] - only for MTP (seq_len > 1) + # Intermediate states: [B, T, HV, K, V] - only for MTP (seq_len > 1) # Write all T steps of intermediate states intermediate_bytes = 0 if seq_len > 1: intermediate_bytes = ( - batch_size * seq_len * num_sab_heads * head_size * head_size * 4 + batch_size + * seq_len + * num_sab_heads + * head_size + * head_size + * state_dtype_bytes ) total_bytes = ( @@ -1800,6 +1832,49 @@ def verify_correctness_pretranspose( # ============================================================================ +def improved_cutedsl_gdn_wrapper( + q: torch.Tensor, # [B, T, H_Q, K] where T=1,2,3,4 + k: torch.Tensor, # [B, T, H_K, K] + v: torch.Tensor, # [B, T, HV, V] + state: torch.Tensor, # [B, HV, V, K] - K-fast layout (pretranspose) + A_log: torch.Tensor, # [HV] + a: torch.Tensor, # [B, T, HV] + dt_bias: torch.Tensor, # [HV] + b: torch.Tensor, # [B, T, HV] + scale: float, + output: torch.Tensor, # [B, T, HV, V] - unused, kernel returns output directly + use_qk_l2norm: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, +): + """ + Wrapper for improved CuTe-DSL GDN kernel. + Supports T=1,2,3,4 (sequence lengths up to 4). + Adapts the interface to match the benchmark's calling convention. + + Note: The kernel returns output directly, no copy needed. + """ + if not IMPROVED_CUTEDSL_AVAILABLE: + raise RuntimeError("Improved CuTe-DSL kernel is not available") + + # Call improved CuTe-DSL kernel directly - no wrapper overhead + # Kernel modifies state in-place and returns output tensor + return improved_cutedsl_gdn( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + initial_state_source=state, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale, + ) + + def format_time(t): """Format time value, returning 'N/A' if None.""" return f"{t:>8.2f}" if t is not None else " N/A" @@ -1955,11 +2030,42 @@ def bench_all_layouts( results["tr_pretrans_us"] = None results["tr_nontrans_us"] = None + # ========== Improved CuTe-DSL Kernel (K-fast/pretranspose layout) ========== + if IMPROVED_CUTEDSL_AVAILABLE: + # Improved CuTe-DSL uses [B, HV, V, K] layout (K-fast, same as pretranspose) + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, # Improved CuTe-DSL uses BF16 state + device="cuda", + ) + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + try: + times = bench_gpu_time( + lambda: improved_cutedsl_gdn_wrapper( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + results["improved_cutedsl_us"] = np.median(times) * 1000 + except Exception as e: + results["improved_cutedsl_us"] = None + print(f" Improved CuTe-DSL kernel failed: {type(e).__name__}: {e}") + else: + results["improved_cutedsl_us"] = None + return results def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): - """Run benchmark comparing all layouts: FlashInfer/Triton x pretranspose/nontranspose.""" + """Run benchmark comparing all layouts: FlashInfer/Triton x pretranspose/nontranspose + CuTe-DSL.""" # Verify correctness first if requested if args.verify and TRITON_AVAILABLE: print("\n=== Correctness Verification ===") @@ -1995,24 +2101,24 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): print(f" Nontranspose: ERROR - {type(e).__name__}") print() - print("\n" + "=" * 120) - print("GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose") + print("\n" + "=" * 160) + print("GDN Decode Benchmark (T=1): FlashInfer vs Triton vs Improved CuTe-DSL") print( f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " f"v_heads={args.num_v_heads}, head_size={args.head_size}, " f"dtype={args.dtype}, qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}" ) - print("=" * 120) + print("=" * 160) print() print( - f"{'batch':>6} | {'FI-PreTr':>8} {'FI-NonTr':>8} | {'TR-PreTr':>8} {'TR-NonTr':>8} | " - f"{'FI/TR-Pre':>9} {'FI/TR-Non':>9} | {'Pre/Non-FI':>10} {'Pre/Non-TR':>10}" + f"{'batch':>6} | {'FI-PreTr':>8} {'FI-NonTr':>8} | {'TR-PreTr':>8} {'TR-NonTr':>8} | {'ImpCuTe':>8} | " + f"{'FI/TR-Pre':>9} {'ImpCuTe/FI':>10} {'ImpCuTe/TR':>10}" ) print( - f"{'':>6} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} {'(us)':>8} | " - f"{'speedup':>9} {'speedup':>9} | {'speedup':>10} {'speedup':>10}" + f"{'':>6} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} | " + f"{'speedup':>9} {'speedup':>10} {'speedup':>10}" ) - print("-" * 120) + print("-" * 160) all_results = [] for batch_size in args.batch_size: @@ -2033,35 +2139,48 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): fi_non = result.get("fi_nontrans_us") tr_pre = result.get("tr_pretrans_us") tr_non = result.get("tr_nontrans_us") + imp_cute = result.get("improved_cutedsl_us") # FI/TR speedup (>1 means FI faster) fi_tr_pre = format_speedup(fi_pre, tr_pre) - fi_tr_non = format_speedup(fi_non, tr_non) - # Pre/Non speedup (>1 means pretranspose faster) - pre_non_fi = format_speedup(fi_pre, fi_non) - pre_non_tr = format_speedup(tr_pre, tr_non) + # Improved CuTe-DSL vs FI-PreTr speedup (>1 means Improved CuTe-DSL faster) + imp_cute_fi_speedup = format_speedup(imp_cute, fi_pre) + + # Improved CuTe-DSL vs TR-PreTr speedup (>1 means Improved CuTe-DSL faster) + imp_cute_tr_speedup = format_speedup(imp_cute, tr_pre) print( f"{batch_size:>6} | {format_time(fi_pre)} {format_time(fi_non)} | " - f"{format_time(tr_pre)} {format_time(tr_non)} | " - f"{fi_tr_pre} {fi_tr_non} | {pre_non_fi} {pre_non_tr}" + f"{format_time(tr_pre)} {format_time(tr_non)} | {format_time(imp_cute)} | " + f"{fi_tr_pre} {imp_cute_fi_speedup:>10} {imp_cute_tr_speedup:>10}" ) - print("-" * 120) + print("-" * 160) print() print("Legend:") print(" FI-PreTr = FlashInfer Pretranspose [B, HV, V, K]") print(" FI-NonTr = FlashInfer Nontranspose [B, HV, K, V]") print(" TR-PreTr = Triton Pretranspose [B, HV, V, K]") print(" TR-NonTr = Triton Nontranspose [B, HV, K, V]") + print( + " ImpCuTe = Improved CuTe-DSL Kernel [B, HV, V, K] (K-fast layout, supports T=1,2,3,4)" + ) print(" FI/TR speedup > 1.0 means FlashInfer is faster than Triton") - print(" Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose") + print( + " ImpCuTe/FI speedup > 1.0 means Improved CuTe-DSL is faster than FlashInfer Pretranspose" + ) + print( + " ImpCuTe/TR speedup > 1.0 means Improved CuTe-DSL is faster than Triton Pretranspose" + ) print() # Summary statistics fi_pre_times = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")] tr_pre_times = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")] + imp_cute_times = [ + r["improved_cutedsl_us"] for r in all_results if r.get("improved_cutedsl_us") + ] if fi_pre_times and tr_pre_times: speedups = [tr / fi for fi, tr in zip(fi_pre_times, tr_pre_times, strict=False)] @@ -2069,6 +2188,190 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): f"FlashInfer vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) + if imp_cute_times and fi_pre_times and len(imp_cute_times) == len(fi_pre_times): + speedups = [ + fi / cute for cute, fi in zip(imp_cute_times, fi_pre_times, strict=False) + ] + print( + f"Improved CuTe-DSL vs FlashInfer (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + ) + + if imp_cute_times and tr_pre_times and len(imp_cute_times) == len(tr_pre_times): + speedups = [ + tr / cute for cute, tr in zip(imp_cute_times, tr_pre_times, strict=False) + ] + print( + f"Improved CuTe-DSL vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + ) + + +# ============================================================================ +# Improved CuTe-DSL Multi-Token Benchmark (T=1,2,3,4) +# ============================================================================ + + +def bench_improved_cutedsl( + batch_size: int, + seq_len: int, # T=1,2,3,4 + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + warmup_iters: int = 10, + bench_iters: int = 100, +): + """Benchmark Improved CuTe-DSL kernel for T=1,2,3,4.""" + if not IMPROVED_CUTEDSL_AVAILABLE: + raise RuntimeError("Improved CuTe-DSL kernel is not available") + + assert seq_len in [1, 2, 3, 4], ( + f"Improved CuTe-DSL supports T=1,2,3,4, got T={seq_len}" + ) + + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs + T = seq_len + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + # Initial state: [B, HV, V, K] (K-fast layout, BF16) + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, + device="cuda", + ) + + # Pre-allocate output + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + # Scale factor + scale = 1.0 / (head_size**0.5) + + # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing) + kernel_times_ms = bench_gpu_time( + lambda: improved_cutedsl_gdn_wrapper( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + + # Calculate metrics + kernel_median_ms = np.median(kernel_times_ms) + flops = gdn_decode_flops( + batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len + ) + # Improved CuTe-DSL uses BF16 state (2 bytes), not FP32 (4 bytes) + bytes_accessed = gdn_decode_bytes( + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + dtype, + seq_len, + disable_state_update=False, + state_dtype_bytes=2, # BF16 state for improved CuTe-DSL + ) + + kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 + kernel_tb_per_sec = ( + bytes_accessed / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 + ) + + return { + "batch_size": batch_size, + "seq_len": seq_len, + "kernel_median_us": kernel_median_ms * 1000, + "kernel_tflops": kernel_tflops, + "kernel_tb_per_sec": kernel_tb_per_sec, + } + + +def run_improved_cutedsl_benchmark(args, dtype, use_qk_l2norm): + """Run Improved CuTe-DSL benchmark for T=1,2,3,4.""" + if not IMPROVED_CUTEDSL_AVAILABLE: + print("Error: Improved CuTe-DSL kernel is not available.") + print("Make sure flashinfer.cute_dsl.gated_delta_rule is importable.") + return + + # Filter seq_len to only valid values (1,2,3,4) + valid_seq_lens = [t for t in args.seq_len if t in [1, 2, 3, 4]] + if not valid_seq_lens: + print("Error: --seq-len must include values from [1, 2, 3, 4]") + return + + print("\n" + "=" * 100) + print(f"Improved CuTe-DSL GDN Benchmark (T={valid_seq_lens})") + print( + f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " + f"v_heads={args.num_v_heads}, head_size={args.head_size}, " + f"dtype={args.dtype}, qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}" + ) + print("=" * 100) + print() + print(f"{'batch':>6} {'T':>4} {'time(us)':>10} {'TFLOPS':>10} {'TB/s':>10}") + print("-" * 100) + + all_results = [] + for batch_size in args.batch_size: + for seq_len in valid_seq_lens: + try: + result = bench_improved_cutedsl( + batch_size=batch_size, + seq_len=seq_len, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) + all_results.append(result) + + print( + f"{result['batch_size']:>6} {result['seq_len']:>4} " + f"{result['kernel_median_us']:>10.2f} " + f"{result['kernel_tflops']:>10.2f} " + f"{result['kernel_tb_per_sec']:>10.2f}" + ) + except Exception as e: + print( + f"{batch_size:>6} {seq_len:>4} {'ERROR':>10} - {type(e).__name__}: {e}" + ) + + print("-" * 100) + print() + + # Summary by T value + for t in valid_seq_lens: + t_results = [r for r in all_results if r["seq_len"] == t] + if t_results: + avg_time = np.mean([r["kernel_median_us"] for r in t_results]) + avg_tflops = np.mean([r["kernel_tflops"] for r in t_results]) + print( + f"T={t}: Average time={avg_time:.2f}us, Average TFLOPS={avg_tflops:.2f}" + ) + # ============================================================================ # Main Entry Points @@ -2357,7 +2660,7 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL) python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 # Single layout comparison: FlashInfer vs Triton (nontranspose) @@ -2371,6 +2674,9 @@ def main(): # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 + + # Improved CuTe-DSL benchmark (T=1,2,3,4) + python benchmarks/bench_gdn_decode.py --version improved_cutedsl --batch-size 1 32 128 512 """, ) parser.add_argument( @@ -2402,16 +2708,16 @@ def main(): parser.add_argument( "--version", type=str, - choices=["pretranspose", "nontranspose", "mtp", "all"], + choices=["pretranspose", "nontranspose", "mtp", "improved_cutedsl", "all"], default="nontranspose", - help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), or all", + help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), improved_cutedsl (T=1,2,3,4), or all", ) parser.add_argument( "--seq-len", type=int, nargs="+", - default=[2, 4, 8], - help="Sequence lengths for MTP benchmark (T > 1)", + default=[1, 2, 3, 4], + help="Sequence lengths: for MTP use T>1, for improved_cutedsl use T=1,2,3,4", ) parser.add_argument( "--cache-intermediate-states", @@ -2466,8 +2772,11 @@ def main(): run_comparison_benchmark(args, dtype, use_qk_l2norm) else: run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm) + elif args.version == "improved_cutedsl": + # Improved CuTe-DSL benchmark for T=1,2,3,4 + run_improved_cutedsl_benchmark(args, dtype, use_qk_l2norm) else: - # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL) run_all_layouts_benchmark(args, dtype, use_qk_l2norm) diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/cute_dsl/gated_delta_rule.py index 9c86ec9141..57a3077262 100644 --- a/flashinfer/cute_dsl/gated_delta_rule.py +++ b/flashinfer/cute_dsl/gated_delta_rule.py @@ -99,7 +99,11 @@ def load_h_chunk_async(h_sh_chunk, h_global, tidx, row_offset): val_layout = cute.make_layout((1, copy_elems)) atom_async_copy = cute.make_copy_atom( - cute.nvgpu.cpasync.CopyG2SOp(), cutlass.BFloat16, num_bits_per_copy=copy_bits + cute.nvgpu.cpasync.CopyG2SOp( + cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL + ), + cutlass.BFloat16, + num_bits_per_copy=copy_bits, ) tiled_copy = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) thr_copy = tiled_copy.get_slice(tidx) @@ -124,14 +128,16 @@ def compute_single_gate( beta_x = softplus_beta * x softplus_x = cutlass.Float32(0.0) if beta_x <= softplus_threshold: - softplus_x = (cutlass.Float32(1.0) / softplus_beta) * cute.math.log( - cutlass.Float32(1.0) + cute.math.exp(beta_x) + softplus_x = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + cute.exp(beta_x, fastmath=True), fastmath=True ) else: softplus_x = x - g = -cute.math.exp(A_log_val) * softplus_x - g_exp = cute.math.exp(g) - beta = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-beta_raw)) + g = -cute.exp(A_log_val, fastmath=True) * softplus_x + g_exp = cute.exp(g, fastmath=True) + beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cute.exp(-beta_raw, fastmath=True) + ) return g_exp, beta @@ -173,8 +179,8 @@ def normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, k_sum_sq, offset=1 << i, mask=0xFFFFFFFF ) - q_norm = cutlass.Float32(1.0) / cute.math.sqrt(q_sum_sq + eps) - k_norm = cutlass.Float32(1.0) / cute.math.sqrt(k_sum_sq + eps) + q_norm = cute.rsqrt(q_sum_sq + eps, fastmath=True) + k_norm = cute.rsqrt(k_sum_sq + eps, fastmath=True) q_scale_factor = q_norm * scale for i in cutlass.range_constexpr(4): diff --git a/tests/gdn/reference_delta_rule.py b/tests/gdn/reference_delta_rule.py index 3fa10e0a2d..712a63cccb 100644 --- a/tests/gdn/reference_delta_rule.py +++ b/tests/gdn/reference_delta_rule.py @@ -137,6 +137,7 @@ def blockwise_linear_attention( | torch.Tensor = 1.0, # float or tensor with num_elems == num_qo_heads decay_exponent_offset=0, kv_dtype: torch.dtype = torch.float32, + state_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: num_qo_heads = q.size(1) head_size = q.size(2) @@ -156,7 +157,7 @@ def blockwise_linear_attention( KVs = [] # FIXME: kernel debug only kv = torch.zeros( (len(seq_lens), num_qo_heads, head_size, head_size), - dtype=kv_dtype, + dtype=state_dtype, device=q.device, ) output = torch.zeros_like(q) @@ -166,7 +167,7 @@ def blockwise_linear_attention( seq_end = seq_offset[seq_idx + 1] blk_offset = seq_start carried_kv = torch.zeros( - (num_qo_heads, head_size, head_size), dtype=kv_dtype, device=q.device + (num_qo_heads, head_size, head_size), dtype=state_dtype, device=q.device ) while blk_offset < seq_end: is_full_block = seq_end - blk_offset >= block_size @@ -205,7 +206,10 @@ def blockwise_linear_attention( ) o_inter = ( - matmul(q_t.transpose(0, 1).to(kv_dtype) * Lq, carried_kv) + matmul( + q_t.transpose(0, 1).to(torch.float32) * Lq, + carried_kv.to(torch.float32), + ) .transpose(0, 1) .to(q.dtype) ) @@ -219,10 +223,10 @@ def blockwise_linear_attention( if (decay_factor == 1.0).all(): inc_kv = matmul( - k_t.transpose(0, 1).transpose(-2, -1).to(kv_dtype), - v_t.transpose(0, 1).to(kv_dtype), + k_t.transpose(0, 1).transpose(-2, -1).to(torch.float32), + v_t.transpose(0, 1).to(torch.float32), ) - carried_kv = carried_kv + inc_kv + carried_kv = (carried_kv.to(torch.float32) + inc_kv).to(state_dtype) else: Lk = LambdaK( decay_factor, @@ -232,11 +236,13 @@ def blockwise_linear_attention( offset=decay_exponent_offset, ) inc_kv = matmul( - (k_t.transpose(0, 1) * Lk).transpose(-2, -1).to(kv_dtype), - v_t.transpose(0, 1).to(kv_dtype), + (k_t.transpose(0, 1) * Lk).transpose(-2, -1).to(torch.float32), + v_t.transpose(0, 1).to(torch.float32), ) block_decay = decay_factor**valid_len - carried_kv = block_decay * carried_kv + inc_kv + carried_kv = (block_decay * carried_kv.to(torch.float32) + inc_kv).to( + state_dtype + ) KVs.append(carried_kv.clone()) blk_offset += block_size @@ -257,6 +263,7 @@ def delta_rule( beta: torch.Tensor | None = None, # [total_seq_len, num_qo_heads] scale_factor=1.0, kv_dtype: torch.dtype = torch.float32, + state_dtype: torch.dtype = torch.float32, ): o = [] kv = [] @@ -297,7 +304,7 @@ def delta_rule( betas = beta[s] state_HKV = torch.zeros( - num_q_heads, head_size, head_size, dtype=kv_dtype, device=q.device + num_q_heads, head_size, head_size, dtype=state_dtype, device=q.device ) for i in range(seq_len): # var_DS where var is variable basename and DS is the dimensional semantics. @@ -311,14 +318,15 @@ def delta_rule( ### listed at the bottom of page3 of section 2.2 DELTA NETWORKS: LINEAR ATTENTION WITH DELTA RULE # state update rule, use the middle version for clearer dimensional semantics - old_state_HKV = alpha_H11 * state_HKV + # Read state in fp32, compute in fp32, store back in state_dtype + old_state_HKV = alpha_H11 * state_HKV.to(torch.float32) old_v_H1V = matmul(k_H1K, old_state_HKV) new_v_H1V = beta_H11 * v_H1V + (1 - beta_H11) * old_v_H1V state_remove = torch.einsum("htv,htk->hkv", old_v_H1V, k_H1K) state_update = torch.einsum("htv,htk->hkv", new_v_H1V, k_H1K) - state_HKV[:] = old_state_HKV - state_remove + state_update + state_HKV[:] = (old_state_HKV - state_remove + state_update).to(state_dtype) - o_H1V = scale_factor * matmul(q_H1Q, state_HKV) + o_H1V = scale_factor * matmul(q_H1Q, state_HKV.to(torch.float32)) o.append(o_H1V.squeeze(1)) kv.append(state_HKV.clone()) @@ -357,6 +365,7 @@ def blockwise_delta_rule( block_size: int = 32, scale_factor=1.0, kv_dtype: torch.dtype = torch.float32, + state_dtype: torch.dtype = torch.float32, # intermediate_outputs = None, # debug output ) -> torch.Tensor: total_seqlen = q.size(0) @@ -386,7 +395,7 @@ def blockwise_delta_rule( kv = torch.zeros( (len(seq_lens), num_sab_heads, head_size, head_size), - dtype=kv_dtype, + dtype=state_dtype, device=q.device, ) output = torch.zeros_like(q) @@ -396,7 +405,7 @@ def blockwise_delta_rule( seq_end = seq_offset[seq_idx + 1] blk_offset = seq_start state_HKV = torch.zeros( - (num_sab_heads, head_size, head_size), dtype=kv_dtype, device=q.device + (num_sab_heads, head_size, head_size), dtype=state_dtype, device=q.device ) while blk_offset < seq_end: is_full_block = seq_end - blk_offset >= block_size @@ -455,7 +464,9 @@ def blockwise_delta_rule( # new_v_HSV = matmul(T, (v_HSV - matmul(torch.exp(gamma_HS1) * k_HSK, state_HKV))) u_HSV = matmul(T, v_HSV) w_HSK = matmul(T, torch.exp(gamma_HS1) * k_HSK) - new_v_HSV = u_HSV - matmul(w_HSK.to(kv_dtype), state_HKV).to(u_HSV.dtype) + new_v_HSV = u_HSV - matmul( + w_HSK.to(torch.float32), state_HKV.to(torch.float32) + ).to(u_HSV.dtype) new_v_SHV = new_v_HSV.transpose(0, 1) # if intermediate_outputs is not None: @@ -468,7 +479,10 @@ def blockwise_delta_rule( # intermediate_outputs["new_v"].append(new_v_HSV.clone()) o_inter = ( - matmul(torch.exp(gamma_HS1) * q_HSQ.to(kv_dtype), state_HKV) + matmul( + torch.exp(gamma_HS1) * q_HSQ.to(torch.float32), + state_HKV.to(torch.float32), + ) .transpose(0, 1) .to(q.dtype) ) @@ -484,10 +498,12 @@ def blockwise_delta_rule( inc_HKV = matmul( (torch.exp(block_gamma - gamma_HS1) * k_HSK) .transpose(-2, -1) - .to(kv_dtype), - new_v_HSV.to(kv_dtype), + .to(torch.float32), + new_v_HSV.to(torch.float32), ) - state_HKV = torch.exp(block_gamma) * state_HKV + inc_HKV + state_HKV = ( + torch.exp(block_gamma) * state_HKV.to(torch.float32) + inc_HKV + ).to(state_dtype) blk_offset += block_size @@ -510,6 +526,7 @@ def decode_delta_rule( softplus_beta: float = 1.0, softplus_threshold: float = 20.0, use_l2_norm: bool = True, + state_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: """ Reference implementation for single-step decode with GDN formula. @@ -537,6 +554,7 @@ def decode_delta_rule( softplus_beta: Beta parameter for softplus activation softplus_threshold: Threshold for softplus numerical stability use_l2_norm: Whether to apply L2 normalization to q and k + state_dtype: Storage dtype for the hidden state (read in fp32, stored in this dtype) Returns: output: [B, num_heads, V] @@ -617,7 +635,7 @@ def decode_delta_rule( # ============================================ # Process each batch and head # ============================================ - new_state = torch.zeros(B, num_heads, K, V, device=device, dtype=dtype) + new_state = torch.zeros(B, num_heads, K, V, device=device, dtype=state_dtype) output = torch.zeros(B, num_heads, V, device=device, dtype=dtype) for b_idx in range(B): @@ -626,7 +644,9 @@ def decode_delta_rule( q_h = q[b_idx, h_idx] # [K] k_h = k[b_idx, h_idx] # [K] v_h = v[b_idx, h_idx] # [V] - h_state = state[b_idx, h_idx].clone() # [K, V] (matches Triton's [BK, BV]) + h_state = ( + state[b_idx, h_idx].clone().to(torch.float32) + ) # [K, V] read as fp32 # Get gating values for this batch and head g_val = g[b_idx, h_idx] # scalar @@ -673,8 +693,8 @@ def decode_delta_rule( # [K] @ [K, V] = [V] output[b_idx, h_idx] = q_h @ h_state - # Store updated state - new_state[b_idx, h_idx] = h_state + # Store updated state (cast back to state_dtype) + new_state[b_idx, h_idx] = h_state.to(state_dtype) return output, new_state @@ -694,6 +714,7 @@ def verify_delta_rule( softplus_threshold: float = 20.0, use_l2_norm: bool = True, cache_intermediate_states: bool = False, + state_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Reference implementation for multi-token (verify mode) delta rule. @@ -715,6 +736,7 @@ def verify_delta_rule( softplus_threshold: Threshold for softplus approximation use_l2_norm: Whether to apply L2 normalization cache_intermediate_states: Whether to cache state at each time step + state_dtype: Storage dtype for the hidden state (read in fp32, stored in this dtype) Returns: output: Output tensor [B, T, num_heads, V] @@ -779,11 +801,13 @@ def verify_delta_rule( # Initialize output and intermediate states output = torch.zeros(B, T, num_heads, V, dtype=torch.float32, device=q.device) - current_state = state.clone() # [B, num_heads, K, V] + current_state = state.clone().to( + state_dtype + ) # [B, num_heads, K, V] stored in state_dtype if cache_intermediate_states: intermediate_states = torch.zeros( - B, T, num_heads, K, V, dtype=torch.float32, device=q.device + B, T, num_heads, K, V, dtype=state_dtype, device=q.device ) else: intermediate_states = None @@ -802,7 +826,9 @@ def verify_delta_rule( q_h = q_t[b_idx, h_idx] # [K] k_h = k_t[b_idx, h_idx] # [K] v_h = v_t[b_idx, h_idx] # [V] - h_state = current_state[b_idx, h_idx].clone() # [K, V] + h_state = ( + current_state[b_idx, h_idx].clone().to(torch.float32) + ) # [K, V] read as fp32 g_val = g_t[b_idx, h_idx] beta_val = beta_t[b_idx, h_idx] @@ -825,11 +851,11 @@ def verify_delta_rule( # 5. Compute output: o = q^T @ h output[b_idx, t, h_idx] = q_h @ h_state # [K] @ [K, V] = [V] - # Update current state - current_state[b_idx, h_idx] = h_state + # Update current state (cast back to state_dtype) + current_state[b_idx, h_idx] = h_state.to(state_dtype) # Cache intermediate state if requested if cache_intermediate_states: - intermediate_states[b_idx, t, h_idx] = h_state + intermediate_states[b_idx, t, h_idx] = h_state.to(state_dtype) return output, current_state, intermediate_states diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 7d93b9ce98..ff59244465 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -41,6 +41,16 @@ ) from flashinfer.utils import get_compute_capability +# Import the improved CuTe-DSL kernel (supports T=1,2,3,4) +try: + from flashinfer.cute_dsl.gated_delta_rule import ( + gated_delta_rule as improved_cutedsl_gdn, + ) + + IMPROVED_CUTEDSL_AVAILABLE = True +except ImportError: + IMPROVED_CUTEDSL_AVAILABLE = False + def _skip_if_not_sm90_or_later(): """Skip test if not Hopper (SM90+) or Blackwell (SM100+) architecture.""" @@ -51,6 +61,7 @@ def _skip_if_not_sm90_or_later(): # ============================================================================ # Test decode kernel with pretranspose version ([B*HV, V, K]) +# Reference: fp32 h state (default); bf16 h state used only for improved_cutedsl_gdn. # ============================================================================ @@ -149,14 +160,12 @@ def _test_decode_kernel_pretranspose( # Remove T dimension for comparison: [B, 1, H, D] -> [B, H, D] our_o = our_o.squeeze(1) - # Reference implementation (remove T=1 dimension) - # Now passes raw GDN parameters, will compute g and beta internally - # Reference uses [B, HV, K, V] state (matches Triton) + # Reference: fp32 h state (default state_dtype) ref_o, ref_state = decode_delta_rule( q.squeeze(1).float(), # [B, 1, H, K] -> [B, H, K] k.squeeze(1).float(), v.squeeze(1).float(), - input_state_ref, # Use [B, HV, K, V] state for reference + input_state_ref, # [B, HV, K, V] A_log=A_log, a=a.squeeze(1), # Remove T dimension: [B, 1, HV] -> [B, HV] dt_bias=dt_bias, @@ -223,6 +232,7 @@ def test_decode_kernel_basic_pretranspose( # ============================================================================ # Test decode kernel with nontranspose version ([pool, HV, K, V]) +# Reference: fp32 h state (default). # ============================================================================ @@ -315,13 +325,12 @@ def _test_decode_kernel_nontranspose( # Remove T dimension for comparison: [B, 1, H, D] -> [B, H, D] our_o = our_o.squeeze(1) - # Reference implementation (remove T=1 dimension) - # Reference uses [B, HV, K, V] state (matches both Triton and nontranspose kernel) + # Reference: fp32 h state (default state_dtype) ref_o, ref_state = decode_delta_rule( q.squeeze(1).float(), # [B, 1, H, K] -> [B, H, K] k.squeeze(1).float(), v.squeeze(1).float(), - input_state, # Use [B, HV, K, V] state for reference + input_state, # [B, HV, K, V] A_log=A_log, a=a.squeeze(1), # Remove T dimension: [B, 1, HV] -> [B, HV] dt_bias=dt_bias, @@ -388,6 +397,7 @@ def test_decode_kernel_basic_nontranspose( # ============================================================================ # Test verify kernel with MTP version (Multiple Token Processing) +# Reference: fp32 h state (default). # ============================================================================ @@ -602,6 +612,213 @@ def test_verify_kernel_mtp( ) +# ============================================================================ +# Test improved CuTe-DSL kernel (supports T=1,2,3,4) +# Reference: bf16 h state only here (state_dtype=torch.bfloat16). Other kernels +# above use fp32 h state reference. +# ============================================================================ + + +def _test_improved_cutedsl_kernel( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + seq_len: int, # T=1,2,3,4 + scale: float, + alpha: bool, + beta: bool, + seed: int | None = None, +): + """Test improved CuTe-DSL kernel for T=1,2,3,4 with bf16 h state. + + Both kernel and reference use bf16 h state: reference runs with + state_dtype=torch.bfloat16 (read h as fp32, compute in fp32, store h in bf16) + so the comparison is apples-to-apples with the improved CuTe-DSL kernel. + """ + _skip_if_not_sm90_or_later() + + if not IMPROVED_CUTEDSL_AVAILABLE: + pytest.skip("Improved CuTe-DSL kernel not available") + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + assert seq_len in [1, 2, 3, 4], ( + f"Improved CuTe-DSL supports T=1,2,3,4, got T={seq_len}" + ) + + # State and GDN parameters are based on num_v_heads (HV in kernel API) + num_sab_heads = num_v_heads + + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + # Generate inputs with T dimension + q = torch.randn(batch_size, seq_len, num_q_heads, head_size, dtype=dtype_torch) + k = torch.randn(batch_size, seq_len, num_k_heads, head_size, dtype=dtype_torch) + v = torch.randn(batch_size, seq_len, num_v_heads, head_size, dtype=dtype_torch) + + # NOTE: Do NOT pre-normalize K here. Both the kernel (use_qk_l2norm_in_kernel=True) + # and reference will apply L2 normalization internally after GQA expansion. + + # Improved CuTe-DSL kernel expects [B, HV, V, K] (K-fast layout) in BF16. + # Use the same bf16 initial state for both kernel and reference so we + # compare the bf16 h state path. + input_state_kernel = torch.randn( + batch_size, num_sab_heads, head_size, head_size, dtype=torch.bfloat16 + ) + + # Reference uses [B, HV, K, V] layout; same bf16 values as kernel. + input_state_ref_bf16 = input_state_kernel.transpose(-2, -1).contiguous() + + # Create GDN-specific parameters + # A_log: log decay parameter [HV] - must be float32 + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + + # dt_bias: decay bias [HV] - must be float32 for improved CuTe-DSL kernel + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + + # a: input-dependent decay [B, T, HV] + a = ( + torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + * 0.1 + ) + + # b: update gate input [B, T, HV] + if beta: + b_tensor = torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + else: + b_tensor = ( + torch.ones( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + * 10.0 + ) + + # Call improved CuTe-DSL kernel + our_state = input_state_kernel.clone() + our_o = improved_cutedsl_gdn( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b_tensor, + initial_state_source=our_state, + use_qk_l2norm_in_kernel=True, + scale=scale, + ) + + torch.cuda.synchronize() + + # Reference implementation with bf16 h state (state_dtype=torch.bfloat16): + # h is stored in bf16, read as fp32 for computation, written back in bf16. + ref_state = input_state_ref_bf16.clone() + ref_outputs = [] + + for t in range(seq_len): + ref_o_t, ref_state = decode_delta_rule( + q[:, t].float(), # [B, H, K] + k[:, t].float(), + v[:, t].float(), + ref_state, # [B, HV, K, V] bf16 + A_log=A_log, + a=a[:, t], # [B, HV] + dt_bias=dt_bias, + b=b_tensor[:, t], # [B, HV] + scale_factor=scale, + softplus_beta=1.0, + softplus_threshold=20.0, + use_l2_norm=True, + state_dtype=torch.bfloat16, # match kernel: h stored in bf16 + ) + ref_outputs.append(ref_o_t) + + # Stack reference outputs: [B, T, HV, V] + ref_o = torch.stack(ref_outputs, dim=1).to(dtype_torch) + + # Tolerances for bf16 h state comparison + atol_o = 0.001 + rtol_o = 0.005 + atol_kv = 0.005 + rtol_kv = 0.005 + + # Compare outputs + torch.testing.assert_close( + our_o.float(), + ref_o.float(), + atol=atol_o, + rtol=rtol_o, + msg=f"Output mismatch for improved CuTe-DSL kernel (B={batch_size}, T={seq_len})", + ) + + # Compare states: both in bf16 (kernel [B, HV, V, K], ref [B, HV, K, V]) + ref_state_transposed = ref_state.transpose(-2, -1).contiguous() + torch.testing.assert_close( + our_state.float(), + ref_state_transposed.float(), + atol=atol_kv, + rtol=rtol_kv, + msg=f"State mismatch for improved CuTe-DSL kernel (B={batch_size}, T={seq_len})", + ) + + print( + f"✓ Improved CuTe-DSL kernel test passed (batch={batch_size}, T={seq_len}, dtype={dtype}, h_state=bf16)" + ) + + +@pytest.mark.parametrize("beta", [True]) +@pytest.mark.parametrize("alpha", [True]) +@pytest.mark.parametrize("scale", ["auto"]) # Use 1/sqrt(K) like compare_flashinfer.py +@pytest.mark.parametrize("seq_len", [1, 2, 3, 4]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize( + "num_q_heads, num_k_heads, num_v_heads", + [(16, 16, 32)], +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_improved_cutedsl_kernel( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + seq_len: int, + scale: float | str, + alpha: bool, + beta: bool, + seed: int = int(os.environ.get("SEED", "0")), +): + scale_val = 1.0 / math.sqrt(head_size) if scale == "auto" else scale + _test_improved_cutedsl_kernel( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + seq_len, + scale_val, + alpha, + beta, + seed, + ) + + if __name__ == "__main__": print("Running smoke tests...") print("\n=== Testing PRETRANSPOSE version ===") @@ -648,15 +865,37 @@ def test_verify_kernel_mtp( seed=42, ) + print("\n=== Testing IMPROVED CuTe-DSL version (T=1,2,3,4) ===") + if IMPROVED_CUTEDSL_AVAILABLE: + for t in [1, 2, 3, 4]: + _test_improved_cutedsl_kernel( + dtype="bfloat16", + batch_size=4, + num_q_heads=16, + num_k_heads=16, + num_v_heads=32, + head_size=128, + seq_len=t, + scale=1.0, + alpha=True, + beta=True, + seed=42, + ) + else: + print("⚠ Improved CuTe-DSL kernel not available, skipping...") + print("\n✅ All smoke tests passed!") print("\nTo run full test suite:") print( - " PRETRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_pretranspose -v" + " PRETRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_pretranspose -v" + ) + print( + " NONTRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_nontranspose -v" ) print( - " NONTRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_nontranspose -v" + " MTP (VERIFY): pytest test_decode_delta_rule.py::test_verify_kernel_mtp -v" ) print( - " MTP (VERIFY): pytest test_decode_delta_rule.py::test_verify_kernel_mtp -v" + " IMPROVED CuTe-DSL: pytest test_decode_delta_rule.py::test_improved_cutedsl_kernel -v" ) print(" ALL: pytest test_decode_delta_rule.py -v") From 8eac4d8193a59bdbbb4d642d6951b81c2e79be68 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Tue, 10 Feb 2026 12:56:48 -0800 Subject: [PATCH 04/11] gated_delta_rule: use (32,4) smem layout for pred_sh/out_sh for coalescing and bank conflicts Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- flashinfer/cute_dsl/gated_delta_rule.py | 90 +++++++++++++------------ 1 file changed, 48 insertions(+), 42 deletions(-) diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/cute_dsl/gated_delta_rule.py index 57a3077262..5a4963a57d 100644 --- a/flashinfer/cute_dsl/gated_delta_rule.py +++ b/flashinfer/cute_dsl/gated_delta_rule.py @@ -755,8 +755,14 @@ def gated_delta_rule_decode_kernel_seqlen1( q_sh = smem.allocate_tensor(cutlass.Float32, 128) k_sh = smem.allocate_tensor(cutlass.Float32, 128) - pred_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) - out_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) + # pred_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) + # out_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) + pred_sh = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + ) + out_sh = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + ) h_global = gH[(batch_idx, value_head_idx, None, None)] @@ -823,13 +829,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) pred = pred + pred2 - pred_sh[warp_idx, lane_idx] = pred + pred_sh[lane_idx, warp_idx] = pred cute.arch.sync_threads() pred_final = ( - pred_sh[0, lane_idx] - + pred_sh[1, lane_idx] - + pred_sh[2, lane_idx] - + pred_sh[3, lane_idx] + pred_sh[lane_idx, 0] + + pred_sh[lane_idx, 1] + + pred_sh[lane_idx, 2] + + pred_sh[lane_idx, 3] ) v_val = (v_sh[lane_idx] - pred_final) * beta @@ -855,13 +861,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) out = out + out2 - out_sh[warp_idx, lane_idx] = out + out_sh[lane_idx, warp_idx] = out cute.arch.sync_threads() out_final = ( - out_sh[0, lane_idx] - + out_sh[1, lane_idx] - + out_sh[2, lane_idx] - + out_sh[3, lane_idx] + out_sh[lane_idx, 0] + + out_sh[lane_idx, 1] + + out_sh[lane_idx, 2] + + out_sh[lane_idx, 3] ) write_h_chunk_to_smem(h_chunk, h_sh_chunk0, lane_idx, k_base) @@ -900,13 +906,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) pred = pred + pred2 - pred_sh[warp_idx, lane_idx] = pred + pred_sh[lane_idx, warp_idx] = pred cute.arch.sync_threads() pred_final = ( - pred_sh[0, lane_idx] - + pred_sh[1, lane_idx] - + pred_sh[2, lane_idx] - + pred_sh[3, lane_idx] + pred_sh[lane_idx, 0] + + pred_sh[lane_idx, 1] + + pred_sh[lane_idx, 2] + + pred_sh[lane_idx, 3] ) v_val = (v_sh[32 + lane_idx] - pred_final) * beta @@ -931,13 +937,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) out = out + out2 - out_sh[warp_idx, lane_idx] = out + out_sh[lane_idx, warp_idx] = out cute.arch.sync_threads() out_final = ( - out_sh[0, lane_idx] - + out_sh[1, lane_idx] - + out_sh[2, lane_idx] - + out_sh[3, lane_idx] + out_sh[lane_idx, 0] + + out_sh[lane_idx, 1] + + out_sh[lane_idx, 2] + + out_sh[lane_idx, 3] ) write_h_chunk_to_smem(h_chunk, h_sh_chunk1, lane_idx, k_base) @@ -971,13 +977,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) pred = pred + pred2 - pred_sh[warp_idx, lane_idx] = pred + pred_sh[lane_idx, warp_idx] = pred cute.arch.sync_threads() pred_final = ( - pred_sh[0, lane_idx] - + pred_sh[1, lane_idx] - + pred_sh[2, lane_idx] - + pred_sh[3, lane_idx] + pred_sh[lane_idx, 0] + + pred_sh[lane_idx, 1] + + pred_sh[lane_idx, 2] + + pred_sh[lane_idx, 3] ) v_val = (v_sh[64 + lane_idx] - pred_final) * beta @@ -1002,13 +1008,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) out = out + out2 - out_sh[warp_idx, lane_idx] = out + out_sh[lane_idx, warp_idx] = out cute.arch.sync_threads() out_final = ( - out_sh[0, lane_idx] - + out_sh[1, lane_idx] - + out_sh[2, lane_idx] - + out_sh[3, lane_idx] + out_sh[lane_idx, 0] + + out_sh[lane_idx, 1] + + out_sh[lane_idx, 2] + + out_sh[lane_idx, 3] ) write_h_chunk_to_smem(h_chunk, h_sh_chunk2, lane_idx, k_base) @@ -1042,13 +1048,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) pred = pred + pred2 - pred_sh[warp_idx, lane_idx] = pred + pred_sh[lane_idx, warp_idx] = pred cute.arch.sync_threads() pred_final = ( - pred_sh[0, lane_idx] - + pred_sh[1, lane_idx] - + pred_sh[2, lane_idx] - + pred_sh[3, lane_idx] + pred_sh[lane_idx, 0] + + pred_sh[lane_idx, 1] + + pred_sh[lane_idx, 2] + + pred_sh[lane_idx, 3] ) v_val = (v_sh[96 + lane_idx] - pred_final) * beta @@ -1073,13 +1079,13 @@ def gated_delta_rule_decode_kernel_seqlen1( ) out = out + out2 - out_sh[warp_idx, lane_idx] = out + out_sh[lane_idx, warp_idx] = out cute.arch.sync_threads() out_final = ( - out_sh[0, lane_idx] - + out_sh[1, lane_idx] - + out_sh[2, lane_idx] - + out_sh[3, lane_idx] + out_sh[lane_idx, 0] + + out_sh[lane_idx, 1] + + out_sh[lane_idx, 2] + + out_sh[lane_idx, 3] ) write_h_chunk_to_smem(h_chunk, h_sh_chunk3, lane_idx, k_base) From ef0034f33ddca62d3ae8a70624934830e868466f Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Tue, 10 Feb 2026 13:13:49 -0800 Subject: [PATCH 05/11] gated_delta_rule: add LowBS-1 kernel for T=1, BS<=4 Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- flashinfer/cute_dsl/gated_delta_rule.py | 213 +++++++++++++++++++++++- 1 file changed, 212 insertions(+), 1 deletion(-) diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/cute_dsl/gated_delta_rule.py index 5a4963a57d..63cdd78b87 100644 --- a/flashinfer/cute_dsl/gated_delta_rule.py +++ b/flashinfer/cute_dsl/gated_delta_rule.py @@ -1525,6 +1525,215 @@ def gated_delta_rule_launch_seqlen1( ) +# ============================================================================== +# LOW-BS SEQLEN=1 KERNEL - 1 V-CHUNK PER CTA (T=1, BS<=4) +# ============================================================================== + + +@cute.kernel +def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( + gQ: cute.Tensor, + gK: cute.Tensor, + gV: cute.Tensor, + ga: cute.Tensor, + gb: cute.Tensor, + gA_log: cute.Tensor, + gdt_bias: cute.Tensor, + gH: cute.Tensor, + gO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, +): + """ + Seqlen=1 kernel with 1 V-chunk (32 V rows) per CTA. + For T=1, batch_size <= 4: more CTAs per batch*head for better SM utilization. + Grid: batch_idx * HV * 4 + value_head_idx * 4 + v_chunk_idx (0..3). + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + HV = cutlass.Int32(gV.shape[2]) + H = cutlass.Int32(gQ.shape[2]) + + batch_idx = bidx // (HV * 4) + remainder = bidx % (HV * 4) + value_head_idx = remainder // 4 + v_chunk_idx = remainder % 4 + + query_head_idx = value_head_idx // (HV // H) + v_row_base = v_chunk_idx * 32 + + smem = utils.SmemAllocator() + + alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + A_log_val = gA_log[value_head_idx] + dt_bias_val = gdt_bias[value_head_idx] + g_exp, beta = compute_single_gate( + alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + ) + + h_sh_chunk = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + + q_sh = smem.allocate_tensor(cutlass.Float32, 128) + k_sh = smem.allocate_tensor(cutlass.Float32, 128) + + pred_sh = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + ) + out_sh = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + ) + + h_global = gH[(batch_idx, value_head_idx, None, None)] + + load_h_chunk_async(h_sh_chunk, h_global, tidx, v_row_base) + nvvm.cp_async_commit_group() + + q_head = gQ[(batch_idx, 0, query_head_idx, None)] + k_head = gK[(batch_idx, 0, query_head_idx, None)] + + warp_idx = tidx // 32 + lane_idx = tidx % 32 + + if warp_idx == 0: + normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps) + + cute.arch.sync_threads() + + v_head = gV[(batch_idx, 0, value_head_idx, None)] + v_sh = smem.allocate_tensor(cutlass.Float32, 32) + if tidx < 32: + v_sh[tidx] = v_head[v_row_base + tidx].to(cutlass.Float32) + + h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) + k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) + qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32) + + k_base = warp_idx * 32 + + for i in cutlass.range_constexpr(32): + k_chunk[i] = k_sh[k_base + i] + + h_out = gH[(batch_idx, value_head_idx, None, None)] + o_head = gO[(batch_idx, 0, value_head_idx, None)] + + nvvm.cp_async_wait_group(0) + cute.arch.sync_threads() + + pred = cutlass.Float32(0.0) + pred2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32), + h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32), + ), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + for i in cutlass.range_constexpr(0, 32, 2): + pred, pred2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(k_chunk[i], k_chunk[i + 1]), + src_c=(pred, pred2), + ) + pred = pred + pred2 + + pred_sh[lane_idx, warp_idx] = pred + cute.arch.sync_threads() + pred_final = ( + pred_sh[lane_idx, 0] + + pred_sh[lane_idx, 1] + + pred_sh[lane_idx, 2] + + pred_sh[lane_idx, 3] + ) + + v_val = (v_sh[lane_idx] - pred_final) * beta + + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(k_chunk[i], k_chunk[i + 1]), + src_b=(v_val, v_val), + src_c=(h_chunk[i], h_chunk[i + 1]), + ) + + for i in cutlass.range_constexpr(32): + qk_temp[i] = q_sh[k_base + i] + + out = cutlass.Float32(0.0) + out2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + out, out2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(qk_temp[i], qk_temp[i + 1]), + src_c=(out, out2), + ) + out = out + out2 + + out_sh[lane_idx, warp_idx] = out + cute.arch.sync_threads() + out_final = ( + out_sh[lane_idx, 0] + + out_sh[lane_idx, 1] + + out_sh[lane_idx, 2] + + out_sh[lane_idx, 3] + ) + + write_h_chunk_to_smem(h_chunk, h_sh_chunk, lane_idx, k_base) + if warp_idx == 0: + o_head[v_row_base + lane_idx] = out_final.to(cutlass.BFloat16) + + cute.arch.sync_threads() + store_h_smem_to_gmem(h_sh_chunk, h_out, tidx, v_row_base) + + +@cute.jit +def gated_delta_rule_launch_seqlen1_lowBS_1chunk( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + """Launch LowBS-1 kernel: 4 CTAs per (batch, value_head).""" + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + ).launch( + grid=[batch_size * HV * 4, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + @cute.jit def gated_delta_rule_launch_seqlen2( mQ: cute.Tensor, @@ -1786,7 +1995,9 @@ def gated_delta_rule( cache_key = (T, B) if cache_key not in _compiled_kernels: # Select and compile the appropriate kernel - if T == 1: + if T == 1 and B <= 4: + launch_fn = gated_delta_rule_launch_seqlen1_lowBS_1chunk + elif T == 1: launch_fn = gated_delta_rule_launch_seqlen1 elif T == 2: launch_fn = gated_delta_rule_launch_seqlen2 From 4fffb167431b7a8ce6e8cf06bd3723f2641edab0 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Tue, 10 Feb 2026 13:17:20 -0800 Subject: [PATCH 06/11] gated_delta_rule: enable tvm-ffi Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- flashinfer/cute_dsl/gated_delta_rule.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/cute_dsl/gated_delta_rule.py index 63cdd78b87..245684f69a 100644 --- a/flashinfer/cute_dsl/gated_delta_rule.py +++ b/flashinfer/cute_dsl/gated_delta_rule.py @@ -1974,15 +1974,15 @@ def gated_delta_rule( output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) - q_ = from_dlpack(q, assumed_align=32) - k_ = from_dlpack(k, assumed_align=32) - v_ = from_dlpack(v, assumed_align=32) - a_ = from_dlpack(a, assumed_align=32) - b_ = from_dlpack(b, assumed_align=32) - A_log_ = from_dlpack(A_log, assumed_align=32) - dt_bias_ = from_dlpack(dt_bias, assumed_align=32) - h_ = from_dlpack(initial_state_source, assumed_align=32) - o_ = from_dlpack(output, assumed_align=32) + q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True) + k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True) + v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True) + a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True) + b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True) + A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True) + dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) + h_ = from_dlpack(initial_state_source, assumed_align=32, enable_tvm_ffi=True) + o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) scale_f32 = cutlass.Float32(scale) softplus_beta_f32 = cutlass.Float32(softplus_beta) @@ -2022,7 +2022,7 @@ def gated_delta_rule( softplus_threshold_f32, eps_f32, stream, - options="--generate-line-info", + options="--enable-tvm-ffi --generate-line-info", ) # Execute From 8fa0e9be519216dee340e32e2732f269e26e5ef1 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Tue, 10 Feb 2026 14:35:33 -0800 Subject: [PATCH 07/11] gdn_decode: add gdn_decode_klast_bf16_state backend and rename from improved_cutedsl - gdn_decode.py: optional backend for pretranspose when bf16 state, T<=4, K=V=128; dispatch to cute_dsl gated_delta_rule. - bench_gdn_decode.py: rename improved_cutedsl to gdn_decode_klast_bf16_state (--version, wrapper, result keys). - test_decode_delta_rule.py: same rename; add test_pretranspose_api_uses_gdn_decode_klast_bf16_state for API dispatch. Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- benchmarks/bench_gdn_decode.py | 152 +++++++++++++++------------- flashinfer/gdn_decode.py | 68 +++++++++++-- tests/gdn/test_decode_delta_rule.py | 150 ++++++++++++++++++++++----- 3 files changed, 268 insertions(+), 102 deletions(-) diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 211748114e..2d2e72bafe 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -18,21 +18,21 @@ GDN Decode Benchmark This benchmark supports: -1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL +1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state 2. Single layout comparison: FlashInfer (CuTe DSL) vs Triton kernel (--compare) 3. MTP benchmark (--version mtp) -4. Improved CuTe-DSL multi-token benchmark (--version improved_cutedsl) for T=1,2,3,4 +4. gdn_decode_klast_bf16_state benchmark (--version gdn_decode_klast_bf16_state) for T=1,2,3,4 Kernels benchmarked: - FlashInfer Pretranspose [B, HV, V, K] (V-major layout) - FlashInfer Nontranspose [B, HV, K, V] (K-major layout) - Triton Pretranspose [B, HV, V, K] - Triton Nontranspose [B, HV, K, V] -- Improved CuTe-DSL Kernel [B, HV, V, K] (K-fast layout, supports T=1,2,3,4) +- gdn_decode_klast_bf16_state [B, HV, V, K] (K-fast layout, T=1..4, bf16 state) from flashinfer.cute_dsl.gated_delta_rule Usage: - # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL) + # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state) python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 # Single layout comparison: FlashInfer vs Triton @@ -44,8 +44,8 @@ # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 - # Improved CuTe-DSL multi-token benchmark (T=1,2,3,4) - python benchmarks/bench_gdn_decode.py --version improved_cutedsl --batch-size 1 32 128 512 + # gdn_decode_klast_bf16_state benchmark (T=1,2,3,4) + python benchmarks/bench_gdn_decode.py --version gdn_decode_klast_bf16_state --batch-size 1 32 128 512 # Use Qwen3-Next preset (q=k=16, v=32, d=128) python benchmarks/bench_gdn_decode.py --preset qwen3-next --batch-size 1 32 128 512 @@ -62,15 +62,15 @@ ) from flashinfer.testing import bench_gpu_time -# Import the improved CuTe-DSL kernel for benchmarking (supports T=1,2,3,4) +# Import the gdn_decode_klast_bf16_state kernel for benchmarking (T=1..4, bf16 state, K-last) try: from flashinfer.cute_dsl.gated_delta_rule import ( - gated_delta_rule as improved_cutedsl_gdn, + gated_delta_rule as gdn_decode_klast_bf16_state, ) - IMPROVED_CUTEDSL_AVAILABLE = True + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True except ImportError: - IMPROVED_CUTEDSL_AVAILABLE = False + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False # ============================================================================ @@ -1832,7 +1832,7 @@ def verify_correctness_pretranspose( # ============================================================================ -def improved_cutedsl_gdn_wrapper( +def gdn_decode_klast_bf16_state_wrapper( q: torch.Tensor, # [B, T, H_Q, K] where T=1,2,3,4 k: torch.Tensor, # [B, T, H_K, K] v: torch.Tensor, # [B, T, HV, V] @@ -1848,18 +1848,18 @@ def improved_cutedsl_gdn_wrapper( softplus_threshold: float = 20.0, ): """ - Wrapper for improved CuTe-DSL GDN kernel. + Wrapper for gdn_decode_klast_bf16_state GDN kernel. Supports T=1,2,3,4 (sequence lengths up to 4). Adapts the interface to match the benchmark's calling convention. Note: The kernel returns output directly, no copy needed. """ - if not IMPROVED_CUTEDSL_AVAILABLE: - raise RuntimeError("Improved CuTe-DSL kernel is not available") + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + raise RuntimeError("gdn_decode_klast_bf16_state kernel is not available") - # Call improved CuTe-DSL kernel directly - no wrapper overhead + # Call gdn_decode_klast_bf16_state kernel directly - no wrapper overhead # Kernel modifies state in-place and returns output tensor - return improved_cutedsl_gdn( + return gdn_decode_klast_bf16_state( A_log=A_log, a=a, dt_bias=dt_bias, @@ -2030,15 +2030,15 @@ def bench_all_layouts( results["tr_pretrans_us"] = None results["tr_nontrans_us"] = None - # ========== Improved CuTe-DSL Kernel (K-fast/pretranspose layout) ========== - if IMPROVED_CUTEDSL_AVAILABLE: - # Improved CuTe-DSL uses [B, HV, V, K] layout (K-fast, same as pretranspose) + # ========== gdn_decode_klast_bf16_state Kernel (K-fast/pretranspose layout) ========== + if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + # gdn_decode_klast_bf16_state uses [B, HV, V, K] layout (K-fast, same as pretranspose) state = torch.randn( batch_size, num_sab_heads, head_size, head_size, - dtype=torch.bfloat16, # Improved CuTe-DSL uses BF16 state + dtype=torch.bfloat16, # gdn_decode_klast_bf16_state uses BF16 state device="cuda", ) output = torch.empty( @@ -2047,19 +2047,21 @@ def bench_all_layouts( try: times = bench_gpu_time( - lambda: improved_cutedsl_gdn_wrapper( + lambda: gdn_decode_klast_bf16_state_wrapper( q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm ), enable_cupti=True, dry_run_iters=warmup_iters, repeat_iters=bench_iters, ) - results["improved_cutedsl_us"] = np.median(times) * 1000 + results["gdn_decode_klast_bf16_state_us"] = np.median(times) * 1000 except Exception as e: - results["improved_cutedsl_us"] = None - print(f" Improved CuTe-DSL kernel failed: {type(e).__name__}: {e}") + results["gdn_decode_klast_bf16_state_us"] = None + print( + f" gdn_decode_klast_bf16_state kernel failed: {type(e).__name__}: {e}" + ) else: - results["improved_cutedsl_us"] = None + results["gdn_decode_klast_bf16_state_us"] = None return results @@ -2102,7 +2104,9 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): print() print("\n" + "=" * 160) - print("GDN Decode Benchmark (T=1): FlashInfer vs Triton vs Improved CuTe-DSL") + print( + "GDN Decode Benchmark (T=1): FlashInfer vs Triton vs gdn_decode_klast_bf16_state" + ) print( f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " f"v_heads={args.num_v_heads}, head_size={args.head_size}, " @@ -2111,8 +2115,8 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): print("=" * 160) print() print( - f"{'batch':>6} | {'FI-PreTr':>8} {'FI-NonTr':>8} | {'TR-PreTr':>8} {'TR-NonTr':>8} | {'ImpCuTe':>8} | " - f"{'FI/TR-Pre':>9} {'ImpCuTe/FI':>10} {'ImpCuTe/TR':>10}" + f"{'batch':>6} | {'FI-PreTr':>8} {'FI-NonTr':>8} | {'TR-PreTr':>8} {'TR-NonTr':>8} | {'KlastBf16':>9} | " + f"{'FI/TR-Pre':>9} {'KlastBf16/FI':>11} {'KlastBf16/TR':>11}" ) print( f"{'':>6} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} | " @@ -2139,21 +2143,21 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): fi_non = result.get("fi_nontrans_us") tr_pre = result.get("tr_pretrans_us") tr_non = result.get("tr_nontrans_us") - imp_cute = result.get("improved_cutedsl_us") + klast_bf16_us = result.get("gdn_decode_klast_bf16_state_us") # FI/TR speedup (>1 means FI faster) fi_tr_pre = format_speedup(fi_pre, tr_pre) - # Improved CuTe-DSL vs FI-PreTr speedup (>1 means Improved CuTe-DSL faster) - imp_cute_fi_speedup = format_speedup(imp_cute, fi_pre) + # gdn_decode_klast_bf16_state vs FI-PreTr speedup (>1 means klast_bf16 faster) + klast_bf16_fi_speedup = format_speedup(klast_bf16_us, fi_pre) - # Improved CuTe-DSL vs TR-PreTr speedup (>1 means Improved CuTe-DSL faster) - imp_cute_tr_speedup = format_speedup(imp_cute, tr_pre) + # gdn_decode_klast_bf16_state vs TR-PreTr speedup (>1 means klast_bf16 faster) + klast_bf16_tr_speedup = format_speedup(klast_bf16_us, tr_pre) print( f"{batch_size:>6} | {format_time(fi_pre)} {format_time(fi_non)} | " - f"{format_time(tr_pre)} {format_time(tr_non)} | {format_time(imp_cute)} | " - f"{fi_tr_pre} {imp_cute_fi_speedup:>10} {imp_cute_tr_speedup:>10}" + f"{format_time(tr_pre)} {format_time(tr_non)} | {format_time(klast_bf16_us)} | " + f"{fi_tr_pre} {klast_bf16_fi_speedup:>10} {klast_bf16_tr_speedup:>10}" ) print("-" * 160) @@ -2164,22 +2168,24 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): print(" TR-PreTr = Triton Pretranspose [B, HV, V, K]") print(" TR-NonTr = Triton Nontranspose [B, HV, K, V]") print( - " ImpCuTe = Improved CuTe-DSL Kernel [B, HV, V, K] (K-fast layout, supports T=1,2,3,4)" + " KlastBf16 = gdn_decode_klast_bf16_state [B, HV, V, K] (K-fast layout, T=1..4, bf16 state)" ) print(" FI/TR speedup > 1.0 means FlashInfer is faster than Triton") print( - " ImpCuTe/FI speedup > 1.0 means Improved CuTe-DSL is faster than FlashInfer Pretranspose" + " KlastBf16/FI speedup > 1.0 means gdn_decode_klast_bf16_state is faster than FlashInfer Pretranspose" ) print( - " ImpCuTe/TR speedup > 1.0 means Improved CuTe-DSL is faster than Triton Pretranspose" + " KlastBf16/TR speedup > 1.0 means gdn_decode_klast_bf16_state is faster than Triton Pretranspose" ) print() # Summary statistics fi_pre_times = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")] tr_pre_times = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")] - imp_cute_times = [ - r["improved_cutedsl_us"] for r in all_results if r.get("improved_cutedsl_us") + klast_bf16_times = [ + r["gdn_decode_klast_bf16_state_us"] + for r in all_results + if r.get("gdn_decode_klast_bf16_state_us") ] if fi_pre_times and tr_pre_times: @@ -2188,29 +2194,29 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): f"FlashInfer vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) - if imp_cute_times and fi_pre_times and len(imp_cute_times) == len(fi_pre_times): + if klast_bf16_times and fi_pre_times and len(klast_bf16_times) == len(fi_pre_times): speedups = [ - fi / cute for cute, fi in zip(imp_cute_times, fi_pre_times, strict=False) + fi / t for t, fi in zip(klast_bf16_times, fi_pre_times, strict=False) ] print( - f"Improved CuTe-DSL vs FlashInfer (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + f"gdn_decode_klast_bf16_state vs FlashInfer (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) - if imp_cute_times and tr_pre_times and len(imp_cute_times) == len(tr_pre_times): + if klast_bf16_times and tr_pre_times and len(klast_bf16_times) == len(tr_pre_times): speedups = [ - tr / cute for cute, tr in zip(imp_cute_times, tr_pre_times, strict=False) + tr / t for t, tr in zip(klast_bf16_times, tr_pre_times, strict=False) ] print( - f"Improved CuTe-DSL vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + f"gdn_decode_klast_bf16_state vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) # ============================================================================ -# Improved CuTe-DSL Multi-Token Benchmark (T=1,2,3,4) +# gdn_decode_klast_bf16_state Multi-Token Benchmark (T=1,2,3,4) # ============================================================================ -def bench_improved_cutedsl( +def bench_gdn_decode_klast_bf16_state( batch_size: int, seq_len: int, # T=1,2,3,4 num_q_heads: int, @@ -2222,12 +2228,12 @@ def bench_improved_cutedsl( warmup_iters: int = 10, bench_iters: int = 100, ): - """Benchmark Improved CuTe-DSL kernel for T=1,2,3,4.""" - if not IMPROVED_CUTEDSL_AVAILABLE: - raise RuntimeError("Improved CuTe-DSL kernel is not available") + """Benchmark gdn_decode_klast_bf16_state kernel for T=1,2,3,4.""" + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + raise RuntimeError("gdn_decode_klast_bf16_state kernel is not available") assert seq_len in [1, 2, 3, 4], ( - f"Improved CuTe-DSL supports T=1,2,3,4, got T={seq_len}" + f"gdn_decode_klast_bf16_state supports T=1,2,3,4, got T={seq_len}" ) num_o_heads = max(num_q_heads, num_v_heads) @@ -2265,7 +2271,7 @@ def bench_improved_cutedsl( # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing) kernel_times_ms = bench_gpu_time( - lambda: improved_cutedsl_gdn_wrapper( + lambda: gdn_decode_klast_bf16_state_wrapper( q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm ), enable_cupti=True, @@ -2278,7 +2284,7 @@ def bench_improved_cutedsl( flops = gdn_decode_flops( batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len ) - # Improved CuTe-DSL uses BF16 state (2 bytes), not FP32 (4 bytes) + # gdn_decode_klast_bf16_state uses BF16 state (2 bytes), not FP32 (4 bytes) bytes_accessed = gdn_decode_bytes( batch_size, num_q_heads, @@ -2288,7 +2294,7 @@ def bench_improved_cutedsl( dtype, seq_len, disable_state_update=False, - state_dtype_bytes=2, # BF16 state for improved CuTe-DSL + state_dtype_bytes=2, # BF16 state for gdn_decode_klast_bf16_state ) kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 @@ -2305,10 +2311,10 @@ def bench_improved_cutedsl( } -def run_improved_cutedsl_benchmark(args, dtype, use_qk_l2norm): - """Run Improved CuTe-DSL benchmark for T=1,2,3,4.""" - if not IMPROVED_CUTEDSL_AVAILABLE: - print("Error: Improved CuTe-DSL kernel is not available.") +def run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm): + """Run gdn_decode_klast_bf16_state benchmark for T=1,2,3,4.""" + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + print("Error: gdn_decode_klast_bf16_state kernel is not available.") print("Make sure flashinfer.cute_dsl.gated_delta_rule is importable.") return @@ -2319,7 +2325,7 @@ def run_improved_cutedsl_benchmark(args, dtype, use_qk_l2norm): return print("\n" + "=" * 100) - print(f"Improved CuTe-DSL GDN Benchmark (T={valid_seq_lens})") + print(f"gdn_decode_klast_bf16_state GDN Benchmark (T={valid_seq_lens})") print( f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " f"v_heads={args.num_v_heads}, head_size={args.head_size}, " @@ -2334,7 +2340,7 @@ def run_improved_cutedsl_benchmark(args, dtype, use_qk_l2norm): for batch_size in args.batch_size: for seq_len in valid_seq_lens: try: - result = bench_improved_cutedsl( + result = bench_gdn_decode_klast_bf16_state( batch_size=batch_size, seq_len=seq_len, num_q_heads=args.num_q_heads, @@ -2675,8 +2681,8 @@ def main(): # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 - # Improved CuTe-DSL benchmark (T=1,2,3,4) - python benchmarks/bench_gdn_decode.py --version improved_cutedsl --batch-size 1 32 128 512 + # gdn_decode_klast_bf16_state benchmark (T=1,2,3,4) + python benchmarks/bench_gdn_decode.py --version gdn_decode_klast_bf16_state --batch-size 1 32 128 512 """, ) parser.add_argument( @@ -2708,16 +2714,22 @@ def main(): parser.add_argument( "--version", type=str, - choices=["pretranspose", "nontranspose", "mtp", "improved_cutedsl", "all"], + choices=[ + "pretranspose", + "nontranspose", + "mtp", + "gdn_decode_klast_bf16_state", + "all", + ], default="nontranspose", - help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), improved_cutedsl (T=1,2,3,4), or all", + help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), gdn_decode_klast_bf16_state (T=1..4, bf16 state, K-last), or all", ) parser.add_argument( "--seq-len", type=int, nargs="+", default=[1, 2, 3, 4], - help="Sequence lengths: for MTP use T>1, for improved_cutedsl use T=1,2,3,4", + help="Sequence lengths: for MTP use T>1, for gdn_decode_klast_bf16_state use T=1,2,3,4", ) parser.add_argument( "--cache-intermediate-states", @@ -2772,11 +2784,11 @@ def main(): run_comparison_benchmark(args, dtype, use_qk_l2norm) else: run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm) - elif args.version == "improved_cutedsl": - # Improved CuTe-DSL benchmark for T=1,2,3,4 - run_improved_cutedsl_benchmark(args, dtype, use_qk_l2norm) + elif args.version == "gdn_decode_klast_bf16_state": + # gdn_decode_klast_bf16_state benchmark for T=1,2,3,4 + run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm) else: - # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL) + # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state) run_all_layouts_benchmark(args, dtype, use_qk_l2norm) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index e64c231686..7bcb0517cb 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -60,6 +60,18 @@ def flashinfer_api(func): # type: ignore[misc] return func +# GDN decode K-last bf16 state kernel (T=1..4, bf16 state, K-last layout) - optional backend +try: + from .cute_dsl.gated_delta_rule import ( + gated_delta_rule as _gated_delta_rule_gdn_decode_klast_bf16_state, + ) + + _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True +except ImportError: + _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False + _gated_delta_rule_gdn_decode_klast_bf16_state = None + + # ============================================================================ # Global configuration for PRETRANSPOSE version ([B*HV, V, K]) # ============================================================================ @@ -952,8 +964,9 @@ def gated_delta_rule_decode_pretranspose( v (torch.Tensor): Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16. state (torch.Tensor): - Current state of shape ``[B, HV, V, K]`` (v-major layout). - Must be float32. Will be updated in-place. + Current state of shape ``[B, HV, V, K]`` (v-major / K-last layout). + Float32: legacy kernel (T=1 only). Bfloat16: gdn_decode_klast_bf16_state backend + when T in 1..4 and K=V=128. Will be updated in-place. A_log (torch.Tensor): Log decay parameter of shape ``[HV]``. Must be float32. a (torch.Tensor): @@ -978,19 +991,61 @@ def gated_delta_rule_decode_pretranspose( Note: - Requires SM90 (Hopper) architecture - State is updated in-place - - K and V must be multiples of 4 for vectorized loads - - State layout is v-major: [B, HV, V, K] + - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 + and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used. + - Legacy path (float32 state, T=1): K and V must be multiples of 4. """ # Validate input shapes B, T, H, K = q.shape - assert T == 1, f"Decode only supports T=1, got T={T}" _, _, HV, V = v.shape - # Validate state shape + # Validate state shape (Qwen-style K-last: [B, HV, V, K]) assert state.shape == (B, HV, V, K), ( f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}" ) + # Backend: gdn_decode_klast_bf16_state when bf16 state, T<=4, K-last layout, K=V=128 + use_gdn_decode_klast_bf16_state = ( + _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE + and state.dtype == torch.bfloat16 + and T in (1, 2, 3, 4) + and K == 128 + and V == 128 + ) + if use_gdn_decode_klast_bf16_state: + assert q.dtype in (torch.float16, torch.bfloat16), ( + f"q must be float16/bfloat16, got {q.dtype}" + ) + assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" + scale_val = K**-0.5 if scale is None else scale + out = _gated_delta_rule_gdn_decode_klast_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b, + initial_state_source=state, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale_val, + ) + output_provided = output is not None + target_dtype = output.dtype if output_provided else q.dtype + if output is not None: + output.copy_(out) + else: + output = out + if output.dtype != target_dtype: + output = output.to(target_dtype) + return output, state + + # Legacy path: T=1 only, float32 state + assert T == 1, f"Decode only supports T=1, got T={T}" + assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" + # Validate K and V constraints assert K >= 128, f"K must be at least 128, got K={K}" assert V >= 128, f"V must be at least 128, got V={V}" @@ -1002,7 +1057,6 @@ def gated_delta_rule_decode_pretranspose( assert q.dtype in (torch.float16, torch.bfloat16), ( f"q must be float16/bfloat16, got {q.dtype}" ) - assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" # Set default scale diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index ff59244465..b234ff4e93 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -41,15 +41,15 @@ ) from flashinfer.utils import get_compute_capability -# Import the improved CuTe-DSL kernel (supports T=1,2,3,4) +# Import the gdn_decode_klast_bf16_state kernel (T=1..4, bf16 state, K-last layout) try: from flashinfer.cute_dsl.gated_delta_rule import ( - gated_delta_rule as improved_cutedsl_gdn, + gated_delta_rule as gdn_decode_klast_bf16_state, ) - IMPROVED_CUTEDSL_AVAILABLE = True + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True except ImportError: - IMPROVED_CUTEDSL_AVAILABLE = False + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False def _skip_if_not_sm90_or_later(): @@ -61,7 +61,7 @@ def _skip_if_not_sm90_or_later(): # ============================================================================ # Test decode kernel with pretranspose version ([B*HV, V, K]) -# Reference: fp32 h state (default); bf16 h state used only for improved_cutedsl_gdn. +# Reference: fp32 h state (default); bf16 h state used only for gdn_decode_klast_bf16_state. # ============================================================================ @@ -613,13 +613,13 @@ def test_verify_kernel_mtp( # ============================================================================ -# Test improved CuTe-DSL kernel (supports T=1,2,3,4) +# Test gdn_decode_klast_bf16_state kernel (T=1..4, bf16 state, K-last) # Reference: bf16 h state only here (state_dtype=torch.bfloat16). Other kernels # above use fp32 h state reference. # ============================================================================ -def _test_improved_cutedsl_kernel( +def _test_gdn_decode_klast_bf16_state_kernel( dtype: str, batch_size: int, num_q_heads: int, @@ -632,23 +632,23 @@ def _test_improved_cutedsl_kernel( beta: bool, seed: int | None = None, ): - """Test improved CuTe-DSL kernel for T=1,2,3,4 with bf16 h state. + """Test gdn_decode_klast_bf16_state kernel for T=1,2,3,4 with bf16 h state. Both kernel and reference use bf16 h state: reference runs with state_dtype=torch.bfloat16 (read h as fp32, compute in fp32, store h in bf16) - so the comparison is apples-to-apples with the improved CuTe-DSL kernel. + so the comparison is apples-to-apples with the gdn_decode_klast_bf16_state kernel. """ _skip_if_not_sm90_or_later() - if not IMPROVED_CUTEDSL_AVAILABLE: - pytest.skip("Improved CuTe-DSL kernel not available") + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + pytest.skip("gdn_decode_klast_bf16_state kernel not available") random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) assert seq_len in [1, 2, 3, 4], ( - f"Improved CuTe-DSL supports T=1,2,3,4, got T={seq_len}" + f"gdn_decode_klast_bf16_state supports T=1,2,3,4, got T={seq_len}" ) # State and GDN parameters are based on num_v_heads (HV in kernel API) @@ -666,7 +666,7 @@ def _test_improved_cutedsl_kernel( # NOTE: Do NOT pre-normalize K here. Both the kernel (use_qk_l2norm_in_kernel=True) # and reference will apply L2 normalization internally after GQA expansion. - # Improved CuTe-DSL kernel expects [B, HV, V, K] (K-fast layout) in BF16. + # gdn_decode_klast_bf16_state kernel expects [B, HV, V, K] (K-fast layout) in BF16. # Use the same bf16 initial state for both kernel and reference so we # compare the bf16 h state path. input_state_kernel = torch.randn( @@ -680,7 +680,7 @@ def _test_improved_cutedsl_kernel( # A_log: log decay parameter [HV] - must be float32 A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 - # dt_bias: decay bias [HV] - must be float32 for improved CuTe-DSL kernel + # dt_bias: decay bias [HV] - must be float32 for gdn_decode_klast_bf16_state kernel dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # a: input-dependent decay [B, T, HV] @@ -704,9 +704,9 @@ def _test_improved_cutedsl_kernel( * 10.0 ) - # Call improved CuTe-DSL kernel + # Call gdn_decode_klast_bf16_state kernel our_state = input_state_kernel.clone() - our_o = improved_cutedsl_gdn( + our_o = gdn_decode_klast_bf16_state( A_log=A_log, a=a, dt_bias=dt_bias, @@ -761,7 +761,7 @@ def _test_improved_cutedsl_kernel( ref_o.float(), atol=atol_o, rtol=rtol_o, - msg=f"Output mismatch for improved CuTe-DSL kernel (B={batch_size}, T={seq_len})", + msg=f"Output mismatch for gdn_decode_klast_bf16_state kernel (B={batch_size}, T={seq_len})", ) # Compare states: both in bf16 (kernel [B, HV, V, K], ref [B, HV, K, V]) @@ -771,11 +771,11 @@ def _test_improved_cutedsl_kernel( ref_state_transposed.float(), atol=atol_kv, rtol=rtol_kv, - msg=f"State mismatch for improved CuTe-DSL kernel (B={batch_size}, T={seq_len})", + msg=f"State mismatch for gdn_decode_klast_bf16_state kernel (B={batch_size}, T={seq_len})", ) print( - f"✓ Improved CuTe-DSL kernel test passed (batch={batch_size}, T={seq_len}, dtype={dtype}, h_state=bf16)" + f"✓ gdn_decode_klast_bf16_state kernel test passed (batch={batch_size}, T={seq_len}, dtype={dtype}, h_state=bf16)" ) @@ -790,7 +790,7 @@ def _test_improved_cutedsl_kernel( ) @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_improved_cutedsl_kernel( +def test_gdn_decode_klast_bf16_state_kernel( dtype: str, num_q_heads: int, num_k_heads: int, @@ -804,7 +804,7 @@ def test_improved_cutedsl_kernel( seed: int = int(os.environ.get("SEED", "0")), ): scale_val = 1.0 / math.sqrt(head_size) if scale == "auto" else scale - _test_improved_cutedsl_kernel( + _test_gdn_decode_klast_bf16_state_kernel( dtype, batch_size, num_q_heads, @@ -819,6 +819,106 @@ def test_improved_cutedsl_kernel( ) +@pytest.mark.parametrize("seq_len", [1, 2, 3, 4]) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize( + "num_q_heads, num_k_heads, num_v_heads", + [(16, 16, 32)], +) +def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + seq_len: int, + seed: int = int(os.environ.get("SEED", "0")), +): + """Verify gated_delta_rule_decode_pretranspose dispatches to gdn_decode_klast_bf16_state when state is bf16 and T<=4, K=V=128. + + Calls the API with bf16 state and checks output/state match the direct gdn_decode_klast_bf16_state call. + """ + _skip_if_not_sm90_or_later() + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + pytest.skip("gdn_decode_klast_bf16_state kernel not available") + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + dtype = torch.bfloat16 + device = torch.device("cuda") + scale = 1.0 / math.sqrt(head_size) + num_sab_heads = num_v_heads + + q = torch.randn( + batch_size, seq_len, num_q_heads, head_size, dtype=dtype, device=device + ) + k = torch.randn( + batch_size, seq_len, num_k_heads, head_size, dtype=dtype, device=device + ) + v = torch.randn( + batch_size, seq_len, num_v_heads, head_size, dtype=dtype, device=device + ) + a = ( + torch.randn(batch_size, seq_len, num_sab_heads, dtype=dtype, device=device) + * 0.1 + ) + b_tensor = torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype, device=device + ) + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + + # State [B, HV, V, K] in bf16 (Qwen-style K-last) so API uses improved backend + state_api = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, + device=device, + ) + state_direct = state_api.clone() + + # Via API (should dispatch to gdn_decode_klast_bf16_state) + out_api, state_api = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=state_api, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b_tensor, + scale=scale, + use_qk_l2norm=True, + ) + + # Direct improved kernel + out_direct = gdn_decode_klast_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b_tensor, + initial_state_source=state_direct, + use_qk_l2norm_in_kernel=True, + scale=scale, + ) + + torch.testing.assert_close(out_api, out_direct, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(state_api, state_direct, atol=1e-2, rtol=1e-2) + print( + f"✓ API gdn_decode_klast_bf16_state backend verified (batch={batch_size}, T={seq_len})" + ) + + if __name__ == "__main__": print("Running smoke tests...") print("\n=== Testing PRETRANSPOSE version ===") @@ -866,9 +966,9 @@ def test_improved_cutedsl_kernel( ) print("\n=== Testing IMPROVED CuTe-DSL version (T=1,2,3,4) ===") - if IMPROVED_CUTEDSL_AVAILABLE: + if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: for t in [1, 2, 3, 4]: - _test_improved_cutedsl_kernel( + _test_gdn_decode_klast_bf16_state_kernel( dtype="bfloat16", batch_size=4, num_q_heads=16, @@ -882,7 +982,7 @@ def test_improved_cutedsl_kernel( seed=42, ) else: - print("⚠ Improved CuTe-DSL kernel not available, skipping...") + print("⚠ gdn_decode_klast_bf16_state kernel not available, skipping...") print("\n✅ All smoke tests passed!") print("\nTo run full test suite:") @@ -896,6 +996,6 @@ def test_improved_cutedsl_kernel( " MTP (VERIFY): pytest test_decode_delta_rule.py::test_verify_kernel_mtp -v" ) print( - " IMPROVED CuTe-DSL: pytest test_decode_delta_rule.py::test_improved_cutedsl_kernel -v" + " gdn_decode_klast_bf16_state: pytest test_decode_delta_rule.py::test_gdn_decode_klast_bf16_state_kernel -v" ) print(" ALL: pytest test_decode_delta_rule.py -v") From 8d1e6b9086007f3742b5064b05780842e3c5f838 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:05:12 -0800 Subject: [PATCH 08/11] Refactor: Consolidate dtype parameters in reference implementation - Consolidate to single state_dtype parameter across all reference functions - Remove duplicate kv_dtype parameter from blockwise_linear_attention(), delta_rule(), and blockwise_delta_rule() - Update test_prefill_delta_rule.py to use state_dtype consistently - Remove benchmark_gated_delta_rule.py from git tracking (keep locally) - Add to .gitignore for local development use only Co-Authored-By: Claude Sonnet 4.5 --- .gitignore | 1 + .../cute_dsl/benchmark_gated_delta_rule.py | 167 ------------------ tests/gdn/reference_delta_rule.py | 3 - tests/gdn/test_prefill_delta_rule.py | 4 +- 4 files changed, 3 insertions(+), 172 deletions(-) delete mode 100644 flashinfer/cute_dsl/benchmark_gated_delta_rule.py diff --git a/.gitignore b/.gitignore index c06b8448ca..1a89e54605 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ csrc/aot_default_additional_params.h # Microbenchmark files microbenchmark/ +flashinfer/cute_dsl/benchmark_gated_delta_rule.py # vscode .vscode/ diff --git a/flashinfer/cute_dsl/benchmark_gated_delta_rule.py b/flashinfer/cute_dsl/benchmark_gated_delta_rule.py deleted file mode 100644 index 22ec5ae7c5..0000000000 --- a/flashinfer/cute_dsl/benchmark_gated_delta_rule.py +++ /dev/null @@ -1,167 +0,0 @@ -""" -Benchmark: Gated Delta Rule CuTe-DSL Kernel - -Simple benchmark showing duration across batch sizes and sequence lengths (T=1,2,3,4). -""" - -import math -import statistics -import torch - - -def get_l2_cache_size(): - """Get L2 cache size in bytes for the current GPU.""" - return torch.cuda.get_device_properties(0).L2_cache_size - - -def benchmark( - func, num_iterations=100, n_warmup=10, flush_l2=True, use_dummy_matmul=True -): - """ - Benchmark a kernel with L2 flushing and return median time in microseconds. - - Args: - func: Function to benchmark - num_iterations: Number of timed iterations - n_warmup: Number of warmup iterations - flush_l2: Whether to flush L2 cache before each iteration - use_dummy_matmul: Whether to use dummy matmul for short-lived kernels - """ - l2_size = get_l2_cache_size() - cache_flush = torch.empty(l2_size, dtype=torch.uint8, device="cuda") - - # Dummy matmul for short-lived kernels (fills GPU pipeline so CUDA events record properly) - if use_dummy_matmul: - A = torch.randn(4096, 4096, dtype=torch.float32, device="cuda") - B = torch.randn(4096, 4096, dtype=torch.float32, device="cuda") - _ = A @ B # Warm up cuBLAS - - # Warmup - for _ in range(n_warmup): - if flush_l2: - cache_flush.zero_() - func() - torch.cuda.synchronize() - - # Benchmark - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iterations)] - - for i in range(num_iterations): - if flush_l2: - cache_flush.zero_() - if use_dummy_matmul: - _ = A @ B # Dummy work to ensure events record properly for short kernels - start_events[i].record() - func() - end_events[i].record() - - torch.cuda.synchronize() - times_us = [ - s.elapsed_time(e) * 1000 for s, e in zip(start_events, end_events, strict=True) - ] - return statistics.median(times_us) - - -def create_inputs(B, T, H=16, HV=32, K=128, V=128): - """Create test inputs.""" - return { - "q": torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16), - "k": torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16), - "v": torch.randn(B, T, HV, V, device="cuda", dtype=torch.bfloat16), - "a": torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16) * 0.1, - "b": torch.randn(B, T, HV, device="cuda", dtype=torch.bfloat16), - "A_log": torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1, - "dt_bias": torch.randn(HV, device="cuda", dtype=torch.float32) * 0.1, - "state": torch.randn(B, HV, V, K, device="cuda", dtype=torch.bfloat16), - "scale": 1.0 / math.sqrt(K), - } - - -def main(): - from gated_delta_rule import gated_delta_rule - - print("=" * 70) - print("Gated Delta Rule CuTe-DSL Kernel Benchmark") - print("Config: H=16, HV=32, K=128, V=128, bfloat16") - print("=" * 70) - - batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] - seqlens = [1, 2, 3, 4] - num_iterations = 100 - - # Results storage - results = {T: {} for T in seqlens} - - # Benchmark each configuration - for T in seqlens: - print(f"\nCompiling and benchmarking T={T}...") - for B in batch_sizes: - inputs = create_inputs(B, T) - state = inputs["state"].clone() - - # Warmup / compile - _ = gated_delta_rule( - A_log=inputs["A_log"], - a=inputs["a"], - dt_bias=inputs["dt_bias"], - q=inputs["q"], - k=inputs["k"], - v=inputs["v"], - b=inputs["b"], - initial_state_source=state, - scale=inputs["scale"], - ) - - def run_kernel(): - return gated_delta_rule( - A_log=inputs["A_log"], - a=inputs["a"], - dt_bias=inputs["dt_bias"], - q=inputs["q"], - k=inputs["k"], - v=inputs["v"], - b=inputs["b"], - initial_state_source=state, - scale=inputs["scale"], - ) - - time_us = benchmark( - run_kernel, - num_iterations=num_iterations, - flush_l2=True, - use_dummy_matmul=True, - ) - results[T][B] = time_us - print(f" B={B:>3}: {time_us:>7.1f} us") - - # Summary table - print("\n" + "=" * 70) - print("SUMMARY: Duration (us) by Batch Size and Sequence Length") - print("=" * 70) - - # Header - header = f"{'B':>6} |" - for T in seqlens: - header += f" T={T} |" - print(header) - print("-" * 70) - - # Data rows - for B in batch_sizes: - row = f"{B:>6} |" - for T in seqlens: - row += f" {results[T][B]:>7.1f} |" - print(row) - - print("-" * 70) - - # Averages - print("\nAverage duration per T:") - for T in seqlens: - avg = sum(results[T].values()) / len(results[T]) - print(f" T={T}: {avg:.1f} us") - - -if __name__ == "__main__": - main() diff --git a/tests/gdn/reference_delta_rule.py b/tests/gdn/reference_delta_rule.py index 712a63cccb..7296610bbd 100644 --- a/tests/gdn/reference_delta_rule.py +++ b/tests/gdn/reference_delta_rule.py @@ -136,7 +136,6 @@ def blockwise_linear_attention( decay_factor: float | torch.Tensor = 1.0, # float or tensor with num_elems == num_qo_heads decay_exponent_offset=0, - kv_dtype: torch.dtype = torch.float32, state_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: num_qo_heads = q.size(1) @@ -262,7 +261,6 @@ def delta_rule( alpha: torch.Tensor | None = None, # [total_seq_len, num_qo_heads] beta: torch.Tensor | None = None, # [total_seq_len, num_qo_heads] scale_factor=1.0, - kv_dtype: torch.dtype = torch.float32, state_dtype: torch.dtype = torch.float32, ): o = [] @@ -364,7 +362,6 @@ def blockwise_delta_rule( beta: torch.Tensor | None = None, # [total_seq_len, num_qo_heads] block_size: int = 32, scale_factor=1.0, - kv_dtype: torch.dtype = torch.float32, state_dtype: torch.dtype = torch.float32, # intermediate_outputs = None, # debug output ) -> torch.Tensor: diff --git a/tests/gdn/test_prefill_delta_rule.py b/tests/gdn/test_prefill_delta_rule.py index f2fd06cbce..8bc87f2ef6 100644 --- a/tests/gdn/test_prefill_delta_rule.py +++ b/tests/gdn/test_prefill_delta_rule.py @@ -117,7 +117,7 @@ def _test_prefill_kernel( scale_factor=scale, alpha=alpha, beta=beta, - kv_dtype=torch.float32, + state_dtype=torch.float32, ) ref_o = ref_o.to(q.dtype) ref_state = ref_state.to(kv_dtype) @@ -364,7 +364,7 @@ def concat_varlen(t1, cu_seq_lens1, t2, cu_seq_lens2): scale_factor=scale, alpha=alpha, beta=beta, - kv_dtype=torch.float32, + state_dtype=torch.float32, ) ref_o = ref_o.to(q.dtype) ref_state = ref_state.to(kv_dtype) From 696dca90b78ca66252cdaeebc3cd5a356dd38395 Mon Sep 17 00:00:00 2001 From: ameynaik-hub <212485788+ameynaik-hub@users.noreply.github.com> Date: Mon, 16 Feb 2026 09:51:27 -0800 Subject: [PATCH 09/11] Update benchmarks/bench_gdn_decode.py Co-authored-by: Zihao Ye --- benchmarks/bench_gdn_decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 2d2e72bafe..8a9d9f9394 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -1836,7 +1836,7 @@ def gdn_decode_klast_bf16_state_wrapper( q: torch.Tensor, # [B, T, H_Q, K] where T=1,2,3,4 k: torch.Tensor, # [B, T, H_K, K] v: torch.Tensor, # [B, T, HV, V] - state: torch.Tensor, # [B, HV, V, K] - K-fast layout (pretranspose) + state: torch.Tensor, # [B, HV, V, K] - K-last layout (pretranspose) A_log: torch.Tensor, # [HV] a: torch.Tensor, # [B, T, HV] dt_bias: torch.Tensor, # [HV] From 4c2c1b57625e4683c1efa33c814c78f48f112c31 Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Mon, 16 Feb 2026 10:58:23 -0800 Subject: [PATCH 10/11] fix: Add parameter validation and improve cache key in gated_delta_rule Add validation for required tensor parameters to fail early with clear error messages. Expand cache key to include all shape dimensions (H, HV, K, V) to prevent incorrect kernel reuse when shapes change. Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- flashinfer/cute_dsl/gated_delta_rule.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/cute_dsl/gated_delta_rule.py index 245684f69a..131350ab7d 100644 --- a/flashinfer/cute_dsl/gated_delta_rule.py +++ b/flashinfer/cute_dsl/gated_delta_rule.py @@ -1964,6 +1964,18 @@ def gated_delta_rule( """ global _compiled_kernels + # Validate required Optional parameters + if q is None: + raise ValueError("q (query tensor) is required") + if k is None: + raise ValueError("k (key tensor) is required") + if v is None: + raise ValueError("v (value tensor) is required") + if b is None: + raise ValueError("b (beta gate tensor) is required") + if initial_state_source is None: + raise ValueError("initial_state_source (H state tensor) is required") + B, T, H, K = q.shape assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}" HV = v.shape[2] @@ -1991,8 +2003,8 @@ def gated_delta_rule( stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Check cache - cache_key = (T, B) + # Check cache - include all shape dimensions to avoid incorrect reuse + cache_key = (T, B, H, HV, K, V) if cache_key not in _compiled_kernels: # Select and compile the appropriate kernel if T == 1 and B <= 4: From 7c5b004a8e3478866a6355448adad169360627da Mon Sep 17 00:00:00 2001 From: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Date: Mon, 16 Feb 2026 12:12:57 -0800 Subject: [PATCH 11/11] refactor: Move GDN CuTe DSL kernel to dedicated gdn_kernels module Relocate gated_delta_rule.py from flashinfer/cute_dsl/ to a new flashinfer/gdn_kernels/ module to improve code organization and clarify the kernel's domain-specific purpose. Changes: - Create flashinfer/gdn_kernels/ module for GDN-specific CuTe DSL kernels - Rename gated_delta_rule.py to gdn_decode_bf16_state.py for clarity (indicates BF16 hidden state variant) - Update all 3 import sites to use new path: - flashinfer/gdn_decode.py - benchmarks/bench_gdn_decode.py - tests/gdn/test_decode_delta_rule.py - Add module __init__.py with proper re-exports - Avoid namespace conflict with existing gdn_decode.py file The flashinfer/cute_dsl/ directory remains for cross-cutting CuTe DSL utilities (RMSNorm, FP4, GEMM+AllReduce, etc.). All 32 GDN decode CuTe DSL tests pass successfully. Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> --- benchmarks/bench_gdn_decode.py | 2 +- flashinfer/gdn_decode.py | 2 +- flashinfer/gdn_kernels/__init__.py | 33 +++++++++++++++++++ .../gdn_decode_bf16_state.py} | 8 +++-- tests/gdn/test_decode_delta_rule.py | 2 +- 5 files changed, 42 insertions(+), 5 deletions(-) create mode 100644 flashinfer/gdn_kernels/__init__.py rename flashinfer/{cute_dsl/gated_delta_rule.py => gdn_kernels/gdn_decode_bf16_state.py} (99%) diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 8a9d9f9394..c1d18f810f 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -64,7 +64,7 @@ # Import the gdn_decode_klast_bf16_state kernel for benchmarking (T=1..4, bf16 state, K-last) try: - from flashinfer.cute_dsl.gated_delta_rule import ( + from flashinfer.gdn_kernels.gdn_decode_bf16_state import ( gated_delta_rule as gdn_decode_klast_bf16_state, ) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 7bcb0517cb..26c742e839 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -62,7 +62,7 @@ def flashinfer_api(func): # type: ignore[misc] # GDN decode K-last bf16 state kernel (T=1..4, bf16 state, K-last layout) - optional backend try: - from .cute_dsl.gated_delta_rule import ( + from .gdn_kernels.gdn_decode_bf16_state import ( gated_delta_rule as _gated_delta_rule_gdn_decode_klast_bf16_state, ) diff --git a/flashinfer/gdn_kernels/__init__.py b/flashinfer/gdn_kernels/__init__.py new file mode 100644 index 0000000000..87da1a90a9 --- /dev/null +++ b/flashinfer/gdn_kernels/__init__.py @@ -0,0 +1,33 @@ +""" +GDN (Gated Delta Rule) Kernels - CuTe DSL Implementations +========================================================= + +This module provides CuTe-DSL implementations of GDN kernels. + +The main gdn_decode.py and gdn_prefill.py files at the top level contain reference +implementations and JIT-compiled kernels. This submodule provides high-performance +CuTe DSL variants optimized for specific use cases. + +Exported Kernels: +- gated_delta_rule: BF16 hidden state decode kernel (T=1,2,3,4) +- GatedDeltaRuleKernel: Kernel class for advanced usage +""" + +from typing import Optional, Type + +try: + from .gdn_decode_bf16_state import ( + gated_delta_rule, + GatedDeltaRuleKernel, + ) + + _has_cute_dsl = True +except ImportError: + _has_cute_dsl = False + gated_delta_rule = None # type: ignore + GatedDeltaRuleKernel: Optional[Type] = None # type: ignore + +__all__ = [ + "gated_delta_rule", + "GatedDeltaRuleKernel", +] diff --git a/flashinfer/cute_dsl/gated_delta_rule.py b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py similarity index 99% rename from flashinfer/cute_dsl/gated_delta_rule.py rename to flashinfer/gdn_kernels/gdn_decode_bf16_state.py index 131350ab7d..9bbbd849c6 100644 --- a/flashinfer/cute_dsl/gated_delta_rule.py +++ b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py @@ -13,8 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. -Gated Delta Rule Kernel (Unified Modular) using CuTe-DSL -======================================================== +Gated Delta Rule Decode Kernel (BF16 Hidden State) - CuTe-DSL Implementation +============================================================================ + +RELOCATED: This file was previously located at flashinfer/cute_dsl/gated_delta_rule.py + and has been moved to flashinfer/gdn_decode/gdn_decode_bf16_state.py + to better reflect its domain-specific purpose (GDN decode with BF16 state). High-performance CUDA kernel implementing the Gated Delta Rule linear attention mechanism for decode-phase inference, supporting sequence lengths T=1, T=2, T=3, T=4. diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index b234ff4e93..963198c8a6 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -43,7 +43,7 @@ # Import the gdn_decode_klast_bf16_state kernel (T=1..4, bf16 state, K-last layout) try: - from flashinfer.cute_dsl.gated_delta_rule import ( + from flashinfer.gdn_kernels.gdn_decode_bf16_state import ( gated_delta_rule as gdn_decode_klast_bf16_state, )