diff --git a/.gitignore b/.gitignore index c06b8448ca..1a89e54605 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ csrc/aot_default_additional_params.h # Microbenchmark files microbenchmark/ +flashinfer/cute_dsl/benchmark_gated_delta_rule.py # vscode .vscode/ diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 9ec10b8fa5..c1d18f810f 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -18,12 +18,21 @@ GDN Decode Benchmark This benchmark supports: -1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose +1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state 2. Single layout comparison: FlashInfer (CuTe DSL) vs Triton kernel (--compare) 3. MTP benchmark (--version mtp) +4. gdn_decode_klast_bf16_state benchmark (--version gdn_decode_klast_bf16_state) for T=1,2,3,4 + +Kernels benchmarked: +- FlashInfer Pretranspose [B, HV, V, K] (V-major layout) +- FlashInfer Nontranspose [B, HV, K, V] (K-major layout) +- Triton Pretranspose [B, HV, V, K] +- Triton Nontranspose [B, HV, K, V] +- gdn_decode_klast_bf16_state [B, HV, V, K] (K-fast layout, T=1..4, bf16 state) + from flashinfer.cute_dsl.gated_delta_rule Usage: - # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state) python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 # Single layout comparison: FlashInfer vs Triton @@ -35,7 +44,10 @@ # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 - # Use Qwen3-Next preset + # gdn_decode_klast_bf16_state benchmark (T=1,2,3,4) + python benchmarks/bench_gdn_decode.py --version gdn_decode_klast_bf16_state --batch-size 1 32 128 512 + + # Use Qwen3-Next preset (q=k=16, v=32, d=128) 
python benchmarks/bench_gdn_decode.py --preset qwen3-next --batch-size 1 32 128 512 """ @@ -50,6 +62,16 @@ ) from flashinfer.testing import bench_gpu_time +# Import the gdn_decode_klast_bf16_state kernel for benchmarking (T=1..4, bf16 state, K-last) +try: + from flashinfer.gdn_kernels.gdn_decode_bf16_state import ( + gated_delta_rule as gdn_decode_klast_bf16_state, + ) + + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True +except ImportError: + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False + # ============================================================================ # Utility Functions @@ -102,6 +124,7 @@ def gdn_decode_bytes( dtype: torch.dtype, seq_len: int = 1, disable_state_update: bool = False, + state_dtype_bytes: int = 4, # 4 for FP32, 2 for BF16 ) -> int: """ Calculate memory bytes for GDN. @@ -110,8 +133,8 @@ def gdn_decode_bytes( Includes: - Q, K, V tensors (input): [B, T, H, K] - dtype - - State tensor (input/output): [B, HV, K, V] - float32 - - Intermediate states (MTP only): [B, T, HV, K, V] - float32 + - State tensor (input/output): [B, HV, K, V] - state_dtype_bytes (FP32=4 or BF16=2) + - Intermediate states (MTP only): [B, T, HV, K, V] - state_dtype_bytes - GDN parameters: A_log (float32), a (dtype), dt_bias (dtype), b (dtype) - Output tensor: [B, T, HV, V] - dtype @@ -129,15 +152,19 @@ def gdn_decode_bytes( # Output tensor: [B, T, HV, V] o_bytes = batch_size * seq_len * num_o_heads * head_size * elem_size - # State tensor (float32): [B, HV, K, V] + # State tensor: [B, HV, K, V] # If disable_state_update=True: only read initial state # If disable_state_update=False: read initial + write final state if disable_state_update: # Read only (e.g., MTP verify mode) - state_bytes = batch_size * num_sab_heads * head_size * head_size * 4 + state_bytes = ( + batch_size * num_sab_heads * head_size * head_size * state_dtype_bytes + ) else: # Read + write (e.g., normal decode) - state_bytes = 2 * batch_size * num_sab_heads * head_size * head_size * 4 + state_bytes = 
( + 2 * batch_size * num_sab_heads * head_size * head_size * state_dtype_bytes + ) # GDN parameters # A_log: [HV] - float32 @@ -149,12 +176,17 @@ def gdn_decode_bytes( # b: [B, T, HV] - dtype b_bytes = batch_size * seq_len * num_sab_heads * elem_size - # Intermediate states (float32): [B, T, HV, K, V] - only for MTP (seq_len > 1) + # Intermediate states: [B, T, HV, K, V] - only for MTP (seq_len > 1) # Write all T steps of intermediate states intermediate_bytes = 0 if seq_len > 1: intermediate_bytes = ( - batch_size * seq_len * num_sab_heads * head_size * head_size * 4 + batch_size + * seq_len + * num_sab_heads + * head_size + * head_size + * state_dtype_bytes ) total_bytes = ( @@ -1800,6 +1832,49 @@ def verify_correctness_pretranspose( # ============================================================================ +def gdn_decode_klast_bf16_state_wrapper( + q: torch.Tensor, # [B, T, H_Q, K] where T=1,2,3,4 + k: torch.Tensor, # [B, T, H_K, K] + v: torch.Tensor, # [B, T, HV, V] + state: torch.Tensor, # [B, HV, V, K] - K-last layout (pretranspose) + A_log: torch.Tensor, # [HV] + a: torch.Tensor, # [B, T, HV] + dt_bias: torch.Tensor, # [HV] + b: torch.Tensor, # [B, T, HV] + scale: float, + output: torch.Tensor, # [B, T, HV, V] - unused, kernel returns output directly + use_qk_l2norm: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, +): + """ + Wrapper for gdn_decode_klast_bf16_state GDN kernel. + Supports T=1,2,3,4 (sequence lengths up to 4). + Adapts the interface to match the benchmark's calling convention. + + Note: The kernel returns output directly, no copy needed. 
+ """ + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + raise RuntimeError("gdn_decode_klast_bf16_state kernel is not available") + + # Call gdn_decode_klast_bf16_state kernel directly - no wrapper overhead + # Kernel modifies state in-place and returns output tensor + return gdn_decode_klast_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + initial_state_source=state, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale, + ) + + def format_time(t): """Format time value, returning 'N/A' if None.""" return f"{t:>8.2f}" if t is not None else " N/A" @@ -1955,11 +2030,44 @@ def bench_all_layouts( results["tr_pretrans_us"] = None results["tr_nontrans_us"] = None + # ========== gdn_decode_klast_bf16_state Kernel (K-fast/pretranspose layout) ========== + if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + # gdn_decode_klast_bf16_state uses [B, HV, V, K] layout (K-fast, same as pretranspose) + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, # gdn_decode_klast_bf16_state uses BF16 state + device="cuda", + ) + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + try: + times = bench_gpu_time( + lambda: gdn_decode_klast_bf16_state_wrapper( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + results["gdn_decode_klast_bf16_state_us"] = np.median(times) * 1000 + except Exception as e: + results["gdn_decode_klast_bf16_state_us"] = None + print( + f" gdn_decode_klast_bf16_state kernel failed: {type(e).__name__}: {e}" + ) + else: + results["gdn_decode_klast_bf16_state_us"] = None + return results def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): - """Run benchmark comparing all layouts: FlashInfer/Triton x pretranspose/nontranspose.""" + """Run benchmark comparing all 
layouts: FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state.
format_speedup(fi_pre, fi_non) - pre_non_tr = format_speedup(tr_pre, tr_non) + # gdn_decode_klast_bf16_state vs FI-PreTr speedup (>1 means klast_bf16 faster) + klast_bf16_fi_speedup = format_speedup(klast_bf16_us, fi_pre) + + # gdn_decode_klast_bf16_state vs TR-PreTr speedup (>1 means klast_bf16 faster) + klast_bf16_tr_speedup = format_speedup(klast_bf16_us, tr_pre) print( f"{batch_size:>6} | {format_time(fi_pre)} {format_time(fi_non)} | " - f"{format_time(tr_pre)} {format_time(tr_non)} | " - f"{fi_tr_pre} {fi_tr_non} | {pre_non_fi} {pre_non_tr}" + f"{format_time(tr_pre)} {format_time(tr_non)} | {format_time(klast_bf16_us)} | " + f"{fi_tr_pre} {klast_bf16_fi_speedup:>10} {klast_bf16_tr_speedup:>10}" ) - print("-" * 120) + print("-" * 160) print() print("Legend:") print(" FI-PreTr = FlashInfer Pretranspose [B, HV, V, K]") print(" FI-NonTr = FlashInfer Nontranspose [B, HV, K, V]") print(" TR-PreTr = Triton Pretranspose [B, HV, V, K]") print(" TR-NonTr = Triton Nontranspose [B, HV, K, V]") + print( + " KlastBf16 = gdn_decode_klast_bf16_state [B, HV, V, K] (K-fast layout, T=1..4, bf16 state)" + ) print(" FI/TR speedup > 1.0 means FlashInfer is faster than Triton") - print(" Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose") + print( + " KlastBf16/FI speedup > 1.0 means gdn_decode_klast_bf16_state is faster than FlashInfer Pretranspose" + ) + print( + " KlastBf16/TR speedup > 1.0 means gdn_decode_klast_bf16_state is faster than Triton Pretranspose" + ) print() # Summary statistics fi_pre_times = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")] tr_pre_times = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")] + klast_bf16_times = [ + r["gdn_decode_klast_bf16_state_us"] + for r in all_results + if r.get("gdn_decode_klast_bf16_state_us") + ] if fi_pre_times and tr_pre_times: speedups = [tr / fi for fi, tr in zip(fi_pre_times, tr_pre_times, strict=False)] @@ -2069,6 +2194,190 @@ def run_all_layouts_benchmark(args, 
dtype, use_qk_l2norm): f"FlashInfer vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) + if klast_bf16_times and fi_pre_times and len(klast_bf16_times) == len(fi_pre_times): + speedups = [ + fi / t for t, fi in zip(klast_bf16_times, fi_pre_times, strict=False) + ] + print( + f"gdn_decode_klast_bf16_state vs FlashInfer (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + ) + + if klast_bf16_times and tr_pre_times and len(klast_bf16_times) == len(tr_pre_times): + speedups = [ + tr / t for t, tr in zip(klast_bf16_times, tr_pre_times, strict=False) + ] + print( + f"gdn_decode_klast_bf16_state vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + ) + + +# ============================================================================ +# gdn_decode_klast_bf16_state Multi-Token Benchmark (T=1,2,3,4) +# ============================================================================ + + +def bench_gdn_decode_klast_bf16_state( + batch_size: int, + seq_len: int, # T=1,2,3,4 + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + warmup_iters: int = 10, + bench_iters: int = 100, +): + """Benchmark gdn_decode_klast_bf16_state kernel for T=1,2,3,4.""" + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + raise RuntimeError("gdn_decode_klast_bf16_state kernel is not available") + + assert seq_len in [1, 2, 3, 4], ( + f"gdn_decode_klast_bf16_state supports T=1,2,3,4, got T={seq_len}" + ) + + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs + T = seq_len + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = 
torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + # Initial state: [B, HV, V, K] (K-fast layout, BF16) + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, + device="cuda", + ) + + # Pre-allocate output + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + # Scale factor + scale = 1.0 / (head_size**0.5) + + # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing) + kernel_times_ms = bench_gpu_time( + lambda: gdn_decode_klast_bf16_state_wrapper( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + + # Calculate metrics + kernel_median_ms = np.median(kernel_times_ms) + flops = gdn_decode_flops( + batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len + ) + # gdn_decode_klast_bf16_state uses BF16 state (2 bytes), not FP32 (4 bytes) + bytes_accessed = gdn_decode_bytes( + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + dtype, + seq_len, + disable_state_update=False, + state_dtype_bytes=2, # BF16 state for gdn_decode_klast_bf16_state + ) + + kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 + kernel_tb_per_sec = ( + bytes_accessed / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 + ) + + return { + "batch_size": batch_size, + "seq_len": seq_len, + "kernel_median_us": kernel_median_ms * 1000, + "kernel_tflops": kernel_tflops, + "kernel_tb_per_sec": kernel_tb_per_sec, + } + + +def run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm): + """Run gdn_decode_klast_bf16_state benchmark for T=1,2,3,4.""" + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + print("Error: gdn_decode_klast_bf16_state kernel is not 
available.") + print("Make sure flashinfer.cute_dsl.gated_delta_rule is importable.") + return + + # Filter seq_len to only valid values (1,2,3,4) + valid_seq_lens = [t for t in args.seq_len if t in [1, 2, 3, 4]] + if not valid_seq_lens: + print("Error: --seq-len must include values from [1, 2, 3, 4]") + return + + print("\n" + "=" * 100) + print(f"gdn_decode_klast_bf16_state GDN Benchmark (T={valid_seq_lens})") + print( + f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " + f"v_heads={args.num_v_heads}, head_size={args.head_size}, " + f"dtype={args.dtype}, qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}" + ) + print("=" * 100) + print() + print(f"{'batch':>6} {'T':>4} {'time(us)':>10} {'TFLOPS':>10} {'TB/s':>10}") + print("-" * 100) + + all_results = [] + for batch_size in args.batch_size: + for seq_len in valid_seq_lens: + try: + result = bench_gdn_decode_klast_bf16_state( + batch_size=batch_size, + seq_len=seq_len, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) + all_results.append(result) + + print( + f"{result['batch_size']:>6} {result['seq_len']:>4} " + f"{result['kernel_median_us']:>10.2f} " + f"{result['kernel_tflops']:>10.2f} " + f"{result['kernel_tb_per_sec']:>10.2f}" + ) + except Exception as e: + print( + f"{batch_size:>6} {seq_len:>4} {'ERROR':>10} - {type(e).__name__}: {e}" + ) + + print("-" * 100) + print() + + # Summary by T value + for t in valid_seq_lens: + t_results = [r for r in all_results if r["seq_len"] == t] + if t_results: + avg_time = np.mean([r["kernel_median_us"] for r in t_results]) + avg_tflops = np.mean([r["kernel_tflops"] for r in t_results]) + print( + f"T={t}: Average time={avg_time:.2f}us, Average TFLOPS={avg_tflops:.2f}" + ) + # ============================================================================ # Main Entry Points @@ -2357,7 
+2666,7 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + Improved CuTe-DSL) python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 # Single layout comparison: FlashInfer vs Triton (nontranspose) @@ -2371,6 +2680,9 @@ def main(): # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 + + # gdn_decode_klast_bf16_state benchmark (T=1,2,3,4) + python benchmarks/bench_gdn_decode.py --version gdn_decode_klast_bf16_state --batch-size 1 32 128 512 """, ) parser.add_argument( @@ -2402,16 +2714,22 @@ def main(): parser.add_argument( "--version", type=str, - choices=["pretranspose", "nontranspose", "mtp", "all"], + choices=[ + "pretranspose", + "nontranspose", + "mtp", + "gdn_decode_klast_bf16_state", + "all", + ], default="nontranspose", - help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), or all", + help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), gdn_decode_klast_bf16_state (T=1..4, bf16 state, K-last), or all", ) parser.add_argument( "--seq-len", type=int, nargs="+", - default=[2, 4, 8], - help="Sequence lengths for MTP benchmark (T > 1)", + default=[1, 2, 3, 4], + help="Sequence lengths: for MTP use T>1, for gdn_decode_klast_bf16_state use T=1,2,3,4", ) parser.add_argument( "--cache-intermediate-states", @@ -2466,8 +2784,11 @@ def main(): run_comparison_benchmark(args, dtype, use_qk_l2norm) else: run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm) + elif args.version == "gdn_decode_klast_bf16_state": + # gdn_decode_klast_bf16_state benchmark for T=1,2,3,4 + run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm) else: - # 
Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state) run_all_layouts_benchmark(args, dtype, use_qk_l2norm) diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 940031453d..9e6762531a 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -53,6 +53,10 @@ add_rmsnorm_fp4quant, AddRMSNormFP4QuantKernel, ) + from .gated_delta_rule import ( + gated_delta_rule, + GatedDeltaRuleKernel, + ) __all__ = [ # Utils (always available) @@ -79,4 +83,7 @@ # Add + RMSNorm + FP4 Quantization "add_rmsnorm_fp4quant", "AddRMSNormFP4QuantKernel", + # Gated Delta Rule + "gated_delta_rule", + "GatedDeltaRuleKernel", ] diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index e64c231686..26c742e839 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -60,6 +60,18 @@ def flashinfer_api(func): # type: ignore[misc] return func +# GDN decode K-last bf16 state kernel (T=1..4, bf16 state, K-last layout) - optional backend +try: + from .gdn_kernels.gdn_decode_bf16_state import ( + gated_delta_rule as _gated_delta_rule_gdn_decode_klast_bf16_state, + ) + + _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True +except ImportError: + _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False + _gated_delta_rule_gdn_decode_klast_bf16_state = None + + # ============================================================================ # Global configuration for PRETRANSPOSE version ([B*HV, V, K]) # ============================================================================ @@ -952,8 +964,9 @@ def gated_delta_rule_decode_pretranspose( v (torch.Tensor): Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16. state (torch.Tensor): - Current state of shape ``[B, HV, V, K]`` (v-major layout). - Must be float32. Will be updated in-place. 
+ Current state of shape ``[B, HV, V, K]`` (v-major / K-last layout). + Float32: legacy kernel (T=1 only). Bfloat16: gdn_decode_klast_bf16_state backend + when T in 1..4 and K=V=128. Will be updated in-place. A_log (torch.Tensor): Log decay parameter of shape ``[HV]``. Must be float32. a (torch.Tensor): @@ -978,19 +991,61 @@ def gated_delta_rule_decode_pretranspose( Note: - Requires SM90 (Hopper) architecture - State is updated in-place - - K and V must be multiples of 4 for vectorized loads - - State layout is v-major: [B, HV, V, K] + - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 + and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used. + - Legacy path (float32 state, T=1): K and V must be multiples of 4. """ # Validate input shapes B, T, H, K = q.shape - assert T == 1, f"Decode only supports T=1, got T={T}" _, _, HV, V = v.shape - # Validate state shape + # Validate state shape (Qwen-style K-last: [B, HV, V, K]) assert state.shape == (B, HV, V, K), ( f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}" ) + # Backend: gdn_decode_klast_bf16_state when bf16 state, T<=4, K-last layout, K=V=128 + use_gdn_decode_klast_bf16_state = ( + _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE + and state.dtype == torch.bfloat16 + and T in (1, 2, 3, 4) + and K == 128 + and V == 128 + ) + if use_gdn_decode_klast_bf16_state: + assert q.dtype in (torch.float16, torch.bfloat16), ( + f"q must be float16/bfloat16, got {q.dtype}" + ) + assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" + scale_val = K**-0.5 if scale is None else scale + out = _gated_delta_rule_gdn_decode_klast_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b, + initial_state_source=state, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale_val, + ) + output_provided = output is not None + target_dtype = output.dtype if output_provided else q.dtype + if 
output is not None: + output.copy_(out) + else: + output = out + if output.dtype != target_dtype: + output = output.to(target_dtype) + return output, state + + # Legacy path: T=1 only, float32 state + assert T == 1, f"Decode only supports T=1, got T={T}" + assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" + # Validate K and V constraints assert K >= 128, f"K must be at least 128, got K={K}" assert V >= 128, f"V must be at least 128, got V={V}" @@ -1002,7 +1057,6 @@ def gated_delta_rule_decode_pretranspose( assert q.dtype in (torch.float16, torch.bfloat16), ( f"q must be float16/bfloat16, got {q.dtype}" ) - assert state.dtype == torch.float32, f"state must be float32, got {state.dtype}" assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" # Set default scale diff --git a/flashinfer/gdn_kernels/__init__.py b/flashinfer/gdn_kernels/__init__.py new file mode 100644 index 0000000000..87da1a90a9 --- /dev/null +++ b/flashinfer/gdn_kernels/__init__.py @@ -0,0 +1,33 @@ +""" +GDN (Gated Delta Rule) Kernels - CuTe DSL Implementations +========================================================= + +This module provides CuTe-DSL implementations of GDN kernels. + +The main gdn_decode.py and gdn_prefill.py files at the top level contain reference +implementations and JIT-compiled kernels. This submodule provides high-performance +CuTe DSL variants optimized for specific use cases. 
+ +Exported Kernels: +- gated_delta_rule: BF16 hidden state decode kernel (T=1,2,3,4) +- GatedDeltaRuleKernel: Kernel class for advanced usage +""" + +from typing import Optional, Type + +try: + from .gdn_decode_bf16_state import ( + gated_delta_rule, + GatedDeltaRuleKernel, + ) + + _has_cute_dsl = True +except ImportError: + _has_cute_dsl = False + gated_delta_rule = None # type: ignore + GatedDeltaRuleKernel: Optional[Type] = None # type: ignore + +__all__ = [ + "gated_delta_rule", + "GatedDeltaRuleKernel", +] diff --git a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py new file mode 100644 index 0000000000..9bbbd849c6 --- /dev/null +++ b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py @@ -0,0 +1,2062 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Gated Delta Rule Decode Kernel (BF16 Hidden State) - CuTe-DSL Implementation +============================================================================ + +RELOCATED: This file was previously located at flashinfer/cute_dsl/gated_delta_rule.py + and has been moved to flashinfer/gdn_decode/gdn_decode_bf16_state.py + to better reflect its domain-specific purpose (GDN decode with BF16 state). + +High-performance CUDA kernel implementing the Gated Delta Rule linear attention +mechanism for decode-phase inference, supporting sequence lengths T=1, T=2, T=3, T=4. 
+ +Key Features: +- Unified kernel architecture: T=2/3/4 share a single compile-time specialized kernel + using Constexpr dispatch, while T=1 uses a separate kernel with persistent K-in-registers +- L2-normalized Q/K with configurable scale +- Gated exponential decay of hidden state H via softplus +- Delta rule updates: v_delta = beta * (v - pred) +- Bank-conflict-free cross-warp reductions +- Async H memory loading with aggressive pipelining +- BF16 tensors with FP32 compute for numerical stability +- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads +""" + +import math +from typing import Optional + +import cutlass +import cutlass.cute as cute +import cuda.bindings.driver as cuda +import torch +from cutlass import utils +from cutlass._mlir.dialects import nvvm +from cutlass.cute.runtime import from_dlpack + +# ============================================================================== +# CONSTANTS +# ============================================================================== +H_SMEM_PADDING = 8 +H_SMEM_STRIDE = 128 + H_SMEM_PADDING + + +# ============================================================================== +# SHARED HELPER FUNCTIONS +# ============================================================================== + + +@cute.jit +def write_h_chunk_to_smem(h_chunk_f32, h_sh_chunk, lane_idx, k_base): + """Write F32 register H chunk to BF16 SMEM.""" + for i in cutlass.range_constexpr(32): + h_sh_chunk[lane_idx, k_base + i] = h_chunk_f32[i].to(cutlass.BFloat16) + + +@cute.jit +def store_h_smem_to_gmem(h_sh_chunk, h_out, tidx, v_row_offset): + """Store H from SMEM to GMEM using 128-bit stores.""" + copy_bits = 128 + copy_elems = copy_bits // cutlass.BFloat16.width + + thr_layout = cute.make_layout((16, 8), stride=(8, 1)) + val_layout = cute.make_layout((1, copy_elems)) + + from cutlass.cute.nvgpu import CopyUniversalOp + + atom_store = cute.make_copy_atom( + CopyUniversalOp(), cutlass.BFloat16, 
num_bits_per_copy=copy_bits + ) + tiled_copy = cute.make_tiled_copy_tv(atom_store, thr_layout, val_layout) + thr_copy = tiled_copy.get_slice(tidx) + + for row_iter in cutlass.range_constexpr(2): + for col_iter in cutlass.range_constexpr(2): + s_tile = cute.local_tile(h_sh_chunk, (16, 64), (row_iter, col_iter)) + g_tile = cute.local_tile( + h_out, (16, 64), (row_iter + (v_row_offset // 16), col_iter) + ) + tS = thr_copy.partition_S(s_tile) + tD = thr_copy.partition_D(g_tile) + cute.copy(atom_store, tS, tD) + + +@cute.jit +def load_h_chunk_async(h_sh_chunk, h_global, tidx, row_offset): + """Load H chunk from GMEM to SMEM using async copy.""" + copy_bits = 128 + copy_elems = copy_bits // cutlass.BFloat16.width + + thr_layout = cute.make_layout((16, 8), stride=(8, 1)) + val_layout = cute.make_layout((1, copy_elems)) + + atom_async_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp( + cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL + ), + cutlass.BFloat16, + num_bits_per_copy=copy_bits, + ) + tiled_copy = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + thr_copy = tiled_copy.get_slice(tidx) + + for row_iter in cutlass.range_constexpr(2): + for col_iter in cutlass.range_constexpr(2): + g_tile = cute.local_tile( + h_global, (16, 64), (row_iter + (row_offset // 16), col_iter) + ) + s_tile = cute.local_tile(h_sh_chunk, (16, 64), (row_iter, col_iter)) + tS = thr_copy.partition_S(g_tile) + tD = thr_copy.partition_D(s_tile) + cute.copy(atom_async_copy, tS, tD) + + +@cute.jit +def compute_single_gate( + alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold +): + """Compute gate values (g_exp, beta) for a single token.""" + x = alpha + dt_bias_val + beta_x = softplus_beta * x + softplus_x = cutlass.Float32(0.0) + if beta_x <= softplus_threshold: + softplus_x = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + cute.exp(beta_x, fastmath=True), fastmath=True + ) + else: + softplus_x = x + g = 
-cute.exp(A_log_val, fastmath=True) * softplus_x
    # Tail of compute_single_gate (its `def` line is above this view):
    # g_exp = exp(g); beta = sigmoid(beta_raw) computed as 1 / (1 + exp(-x)).
    g_exp = cute.exp(g, fastmath=True)
    beta = cutlass.Float32(1.0) / (
        cutlass.Float32(1.0) + cute.exp(-beta_raw, fastmath=True)
    )
    return g_exp, beta


@cute.jit
def normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps):
    """L2-normalize Q and K vectors, then store to shared memory.

    Runs in a single warp: each of the 32 lanes owns 4 strided elements
    (lane_idx + i*32, i in 0..3) of a 128-wide head vector.  Q is additionally
    multiplied by `scale` after normalization; K is normalized only.
    """
    q_reg = cute.make_rmem_tensor((4,), cutlass.Float32)
    k_reg = cute.make_rmem_tensor((4,), cutlass.Float32)

    # Gather this lane's 4 elements of Q and K (stride-32 across the warp).
    for i in cutlass.range_constexpr(4):
        q_reg[i] = q_head[lane_idx + i * 32].to(cutlass.Float32)
        k_reg[i] = k_head[lane_idx + i * 32].to(cutlass.Float32)

    # Sum of squares, accumulated in two lanes of a packed f32x2 FMA and
    # folded together afterwards.
    q_sum_sq = cutlass.Float32(0.0)
    k_sum_sq = cutlass.Float32(0.0)
    q_sum_sq2 = cutlass.Float32(0.0)
    k_sum_sq2 = cutlass.Float32(0.0)

    for i in cutlass.range_constexpr(0, 4, 2):
        q_sum_sq, q_sum_sq2 = cute.arch.fma_packed_f32x2(
            src_a=(q_reg[i], q_reg[i + 1]),
            src_b=(q_reg[i], q_reg[i + 1]),
            src_c=(q_sum_sq, q_sum_sq2),
        )
        k_sum_sq, k_sum_sq2 = cute.arch.fma_packed_f32x2(
            src_a=(k_reg[i], k_reg[i + 1]),
            src_b=(k_reg[i], k_reg[i + 1]),
            src_c=(k_sum_sq, k_sum_sq2),
        )

    q_sum_sq = q_sum_sq + q_sum_sq2
    k_sum_sq = k_sum_sq + k_sum_sq2

    # Butterfly-shuffle tree reduction over offsets 1,2,4,8,16: after 5 steps
    # every lane holds the full 32-lane (i.e. 128-element) sum of squares.
    for i in cutlass.range_constexpr(5):
        q_sum_sq = q_sum_sq + cute.arch.shuffle_sync_bfly(
            q_sum_sq, offset=1 << i, mask=0xFFFFFFFF
        )
        k_sum_sq = k_sum_sq + cute.arch.shuffle_sync_bfly(
            k_sum_sq, offset=1 << i, mask=0xFFFFFFFF
        )

    # 1/sqrt(sum_sq + eps); eps guards against division by zero for a zero vector.
    q_norm = cute.rsqrt(q_sum_sq + eps, fastmath=True)
    k_norm = cute.rsqrt(k_sum_sq + eps, fastmath=True)
    q_scale_factor = q_norm * scale

    # Scatter the normalized values back to SMEM with the same stride-32 pattern.
    for i in cutlass.range_constexpr(4):
        q_sh[lane_idx + i * 32] = q_reg[i] * q_scale_factor
        k_sh[lane_idx + i * 32] = k_reg[i] * k_norm


@cute.jit
def load_v_to_smem(v_head, v_sh, tidx):
    """Load V values from GMEM to SMEM: one f32 element per thread (index tidx)."""
    v_sh[tidx] = v_head[tidx].to(cutlass.Float32)


@cute.jit
def load_kq_chunk_from_smem(kq_sh, kq_chunk, k_base):
    """Load a 32-element K or Q chunk from SMEM (starting at k_base) into registers."""
    for i in cutlass.range_constexpr(32):
        kq_chunk[i] = kq_sh[k_base + i]


@cute.jit
def decay_h_from_smem_and_compute_pred(
    h_sh_chunk, h_chunk, kq_chunk, g_exp, lane_idx, k_base
):
    """Load H from SMEM, apply decay, and compute pred = sum_k(h * k).

    Reads 32 H values of row `lane_idx` (columns k_base..k_base+31), scales
    each by g_exp into h_chunk (via packed FMA with a zero addend), then forms
    the dot product with kq_chunk using two interleaved accumulators that are
    folded at the end.  Returns this warp's 32-element partial `pred`;
    the caller reduces it across warps.
    """
    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)

    # h = H * g_exp (decay applied while converting bf16 SMEM -> f32 registers).
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )

    # pred += h * k with dual accumulators.
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(kq_chunk[i], kq_chunk[i + 1]),
            src_c=(pred, pred2),
        )

    pred = pred + pred2
    return pred


@cute.jit
def update_h_with_delta(h_chunk, kq_chunk, v_delta):
    """Update H with delta rule: h = h + k * v_delta (packed FMA, in registers)."""
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(kq_chunk[i], kq_chunk[i + 1]),
            src_b=(v_delta, v_delta),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )


@cute.jit
def compute_output(h_chunk, kq_chunk):
    """Compute output = sum_k(h * q) over this warp's 32-element chunk.

    Returns a per-warp partial sum; callers reduce it across warps.
    Uses two interleaved accumulators folded at the end.
    """
    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(kq_chunk[i], kq_chunk[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2
    return out


@cute.jit
def decay_h_in_place(h_chunk, g_exp):
    """Apply decay to register-resident H in place: h = h * g_exp (zero-addend FMA)."""
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )


@cute.jit
def cross_warp_reduce_single(reduce_sh, slot, warp_idx, lane_idx,
    value):
    """
    Cross-warp reduction for a single value using bank-conflict-free layout.
    Layout: [slot, lane_idx, warp_idx]

    Each warp writes its partial into its own column, a block-wide barrier
    publishes the writes, then every thread sums the 4 warp columns for its
    lane.  All threads return the same reduced value for a given lane_idx.
    """
    reduce_sh[slot, lane_idx, warp_idx] = value
    cute.arch.sync_threads()
    reduced_value = (
        reduce_sh[slot, lane_idx, 0]
        + reduce_sh[slot, lane_idx, 1]
        + reduce_sh[slot, lane_idx, 2]
        + reduce_sh[slot, lane_idx, 3]
    )
    return reduced_value


@cute.jit
def cross_warp_reduce_two(reduce_sh, slot1, slot2, warp_idx, lane_idx, value1, value2):
    """
    Cross-warp reduction for two values simultaneously using bank-conflict-free layout.
    Layout: [slot, lane_idx, warp_idx]

    Same scheme as cross_warp_reduce_single, but amortizes one barrier over
    two reductions (distinct slots).  Slots must not be reused before the next
    block-wide barrier.
    """
    reduce_sh[slot1, lane_idx, warp_idx] = value1
    reduce_sh[slot2, lane_idx, warp_idx] = value2
    cute.arch.sync_threads()
    reduced1 = (
        reduce_sh[slot1, lane_idx, 0]
        + reduce_sh[slot1, lane_idx, 1]
        + reduce_sh[slot1, lane_idx, 2]
        + reduce_sh[slot1, lane_idx, 3]
    )
    reduced2 = (
        reduce_sh[slot2, lane_idx, 0]
        + reduce_sh[slot2, lane_idx, 1]
        + reduce_sh[slot2, lane_idx, 2]
        + reduce_sh[slot2, lane_idx, 3]
    )
    return reduced1, reduced2


@cute.jit
def process_first_token(
    h_sh_chunk_curr,
    h_chunk,
    kq_chunk,
    k_sh,
    q_sh,
    v_sh,
    reduce_sh,
    o_head,
    g_exp,
    beta,
    v_offset,
    pred_slot,
    warp_idx,
    lane_idx,
    k_base,
):
    """
    Process the first token in a V-chunk (token index 0).
    - Load K from SMEM
    - Decay H from SMEM and compute pred
    - Cross-warp reduce pred (uses pred_slot)
    - Update H with delta
    - Load Q and compute output
    Returns: out (partial output, not yet reduced)

    The returned per-warp partial is reduced later, jointly with the next
    token's pred, by process_middle_token / process_last_token_and_finish.
    NOTE(review): o_head is accepted but not written here — the first token's
    output is stored by whichever processor handles the following token.
    """
    # Load K for this token
    load_kq_chunk_from_smem(k_sh, kq_chunk, k_base)

    # Decay H from SMEM and compute pred = H * K
    pred = decay_h_from_smem_and_compute_pred(
        h_sh_chunk_curr, h_chunk, kq_chunk, g_exp, lane_idx, k_base
    )

    # Reduce pred across warps (slot 0 for first token)
    pred_final = cross_warp_reduce_single(
        reduce_sh, pred_slot, warp_idx, lane_idx, pred
    )

    # Compute delta and update H
    v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta
    update_h_with_delta(h_chunk, kq_chunk, v_delta)

    # Load Q and compute output
    load_kq_chunk_from_smem(q_sh, kq_chunk, k_base)
    out = compute_output(h_chunk, kq_chunk)

    return out


@cute.jit
def process_middle_token(
    h_chunk,
    kq_chunk,
    k_sh,
    q_sh,
    v_sh,
    reduce_sh,
    o_head_prev,
    g_exp,
    beta,
    v_offset,
    out_slot_prev,
    pred_slot,
    out_prev,
    warp_idx,
    lane_idx,
    k_base,
):
    """
    Process a middle token (token 1 for the T=3 kernel; tokens 1 and 2 for the
    T=4 kernel).
    Unified V-chunk processing for 2, 3, or 4 tokens using Constexpr parameter.

    This function handles V-chunk processing for all multi-token cases (T=2, T=3, T=4)
    using compile-time specialization via NUM_TOKENS.

    Pattern:
    - Token 0: First token (always)
    - Tokens 1 to NUM_TOKENS-2: Middle tokens (compile-time unrolled)
    - Token NUM_TOKENS-1: Last token (always)

    Slot numbering for reduce_sh is disjoint per token within the chunk, so a
    single barrier inside each cross-warp reduction suffices; slots are only
    reused across V-chunks, which are separated by a block-wide sync in the
    caller.
    """
    # Store previous H chunk if needed (overlap writeback with this chunk's compute)
    if store_prev:
        store_h_smem_to_gmem(h_sh_chunk_prev, h_out, tidx, prev_v_offset)

    # Token 0: First token processing (always executed)
    out0 = process_first_token(
        h_sh_chunk_curr,
        h_chunk,
        kq_chunk,
        k_sh0,
        q_sh0,
        v_sh0,
        reduce_sh,
        o_head0,
        g_exp0,
        beta0,
        v_offset,
        0,  # pred_slot=0
        warp_idx,
        lane_idx,
        k_base,
    )

    # Compile-time dispatch based on NUM_TOKENS
    if NUM_TOKENS == 2:
        # For T=2: Token 1 is the last token
        process_last_token_and_finish(
            h_sh_chunk_curr,
            h_chunk,
            kq_chunk,
            k_sh1,
            q_sh1,
            v_sh1,
            reduce_sh,
            o_head0,
            o_head1,
            g_exp1,
            beta1,
            v_offset,
            1,
            2,
            3,  # out_slot_prev=1, pred_slot=2, out_slot_last=3
            out0,
            warp_idx,
            lane_idx,
            k_base,
        )
    elif NUM_TOKENS == 3:
        # For T=3: Token 1 is middle, Token 2 is last
        out1 = process_middle_token(
            h_chunk,
            kq_chunk,
            k_sh1,
            q_sh1,
            v_sh1,
            reduce_sh,
            o_head0,
            g_exp1,
            beta1,
            v_offset,
            1,
            2,  # out_slot_prev=1, pred_slot=2
            out0,
            warp_idx,
            lane_idx,
            k_base,
        )
        process_last_token_and_finish(
            h_sh_chunk_curr,
            h_chunk,
            kq_chunk,
            k_sh2,
            q_sh2,
            v_sh2,
            reduce_sh,
            o_head1,
            o_head2,
            g_exp2,
            beta2,
            v_offset,
            3,
            4,
            5,  # out_slot_prev=3, pred_slot=4, out_slot_last=5
            out1,
            warp_idx,
            lane_idx,
            k_base,
        )
    else:
        # For T=4: Tokens 1,2 are middle, Token 3 is last
        out1 = process_middle_token(
            h_chunk,
            kq_chunk,
            k_sh1,
            q_sh1,
            v_sh1,
            reduce_sh,
            o_head0,
            g_exp1,
            beta1,
            v_offset,
            1,
            2,  # out_slot_prev=1, pred_slot=2
            out0,
            warp_idx,
            lane_idx,
            k_base,
        )
        out2 = process_middle_token(
            h_chunk,
            kq_chunk,
            k_sh2,
            q_sh2,
            v_sh2,
            reduce_sh,
            o_head1,
            g_exp2,
            beta2,
            v_offset,
            3,
            4,  # out_slot_prev=3, pred_slot=4
            out1,
            warp_idx,
            lane_idx,
            k_base,
        )
        # Last token for NUM_TOKENS=4: Token 3
        process_last_token_and_finish(
            h_sh_chunk_curr,
            h_chunk,
            kq_chunk,
            k_sh3,
            q_sh3,
            v_sh3,
            reduce_sh,
            o_head2,
            o_head3,
            g_exp3,
            beta3,
            v_offset,
            5,
            6,
            7,  # out_slot_prev=5, pred_slot=6, out_slot_last=7
            out2,
            warp_idx,
            lane_idx,
            k_base,
        )


# ==============================================================================
# SEQLEN=1 KERNEL (Persistent K Optimization)
# ==============================================================================


@cute.kernel
def gated_delta_rule_decode_kernel_seqlen1(
    gQ: cute.Tensor,
    gK: cute.Tensor,
    gV: cute.Tensor,
    ga: cute.Tensor,
    gb: cute.Tensor,
    gA_log: cute.Tensor,
    gdt_bias: cute.Tensor,
    gH: cute.Tensor,
    gO: cute.Tensor,
    scale: cutlass.Float32,
    softplus_beta: cutlass.Float32,
    softplus_threshold: cutlass.Float32,
    eps: cutlass.Float32,
):
    """
    Seqlen=1 kernel with persistent K optimization.
    OPTIMIZATIONS:
    1. PERSISTENT K IN REGISTERS ONLY: K[k_base:k_base+32] kept for entire kernel
       Q is reloaded per chunk (lower register pressure than V3)
    2. AGGRESSIVE PIPELINING: Load chunks 2 ahead, store during next compute
    3.
       [4,32] CROSS-WARP REDUCTION: Correct lane-preserving reduction
    """
    # One CTA per (batch, value head); 128 threads = 4 warps.
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    HV = cutlass.Int32(gV.shape[2])
    H = cutlass.Int32(gQ.shape[2])

    batch_idx = bidx // HV
    value_head_idx = bidx % HV
    # GQA-style mapping: HV // H value heads share one query head.
    query_head_idx = value_head_idx // (HV // H)

    smem = utils.SmemAllocator()

    # Compute gates using shared helper
    alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32)
    beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32)
    A_log_val = gA_log[value_head_idx]
    dt_bias_val = gdt_bias[value_head_idx]
    g_exp, beta = compute_single_gate(
        alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold
    )

    # Allocate SMEM: four 32x128 bf16 H tiles (double-buffered pairs for the
    # cp.async pipeline); H_SMEM_STRIDE row stride is defined elsewhere in the file.
    h_sh_chunk0 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk1 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk2 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk3 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )

    q_sh = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh = smem.allocate_tensor(cutlass.Float32, 128)

    # Cross-warp reduction scratch, indexed [lane_idx, warp_idx] with stride
    # (1, 32): each warp writes a contiguous 32-float column.
    # pred_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32)))
    # out_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32)))
    pred_sh = smem.allocate_tensor(
        cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32))
    )
    out_sh = smem.allocate_tensor(
        cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32))
    )

    h_global = gH[(batch_idx, value_head_idx, None, None)]

    # Launch first 2 async loads (chunks 2 and 3 are issued later, during chunk 1)
    load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0)
    nvvm.cp_async_commit_group()
    load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32)
    nvvm.cp_async_commit_group()

    # L2 normalization
    q_head = gQ[(batch_idx, 0, query_head_idx, None)]
    k_head = gK[(batch_idx, 0, query_head_idx, None)]

    warp_idx = tidx // 32
    lane_idx = tidx % 32

    # Use shared helper for Q/K normalization (only warp 0 does the work)
    if warp_idx == 0:
        normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps)

    cute.arch.sync_threads()

    # Load V
    v_head = gV[(batch_idx, 0, value_head_idx, None)]
    v_sh = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh[tidx] = v_head[tidx].to(cutlass.Float32)

    # Registers: h_chunk + k_chunk (persistent) + qk_temp (reused for Q)
    h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32)
    k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32)  # PERSISTENT K!
    qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32)

    # Each warp owns a 32-wide K-slice: columns k_base..k_base+31.
    k_base = warp_idx * 32

    # Load K ONCE - keep for entire kernel
    for i in cutlass.range_constexpr(32):
        k_chunk[i] = k_sh[k_base + i]

    h_out = gH[(batch_idx, value_head_idx, None, None)]
    o_head = gO[(batch_idx, 0, value_head_idx, None)]

    # ========================================================================
    # CHUNK 0 (V rows 0..31)
    # ========================================================================
    # wait_group(1): allow 1 group still in flight -> chunk 0's copy has landed.
    nvvm.cp_async_wait_group(1)
    cute.arch.sync_threads()

    # Decay H while loading from SMEM, and compute pred = sum_k(H * K)
    # with dual accumulators (lane owns V row lane_idx, warp owns K slice).
    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk0[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk0[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    # Cross-warp reduction of pred: each warp writes its column, then every
    # thread sums the 4 warp partials for its lane.
    pred_sh[lane_idx, warp_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[lane_idx, 0]
        + pred_sh[lane_idx, 1]
        + pred_sh[lane_idx, 2]
        + pred_sh[lane_idx, 3]
    )

    # Delta-rule update: h += k * (v - pred) * beta
    v_val = (v_sh[lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    # Load Q for output computation
    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    # out = sum_k(h * q), then cross-warp reduce like pred above.
    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[lane_idx, warp_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[lane_idx, 0]
        + out_sh[lane_idx, 1]
        + out_sh[lane_idx, 2]
        + out_sh[lane_idx, 3]
    )

    # Stage updated H back to SMEM; warp 0 writes the 32 output elements.
    write_h_chunk_to_smem(h_chunk, h_sh_chunk0, lane_idx, k_base)
    if warp_idx == 0:
        o_head[lane_idx] = out_final.to(cutlass.BFloat16)

    # ========================================================================
    # CHUNK 1 (V rows 32..63) — same pipeline as chunk 0; additionally issues
    # the async loads for chunks 2/3 and flushes chunk 0's H to GMEM.
    # ========================================================================
    nvvm.cp_async_wait_group(0)
    cute.arch.sync_threads()

    load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64)
    nvvm.cp_async_commit_group()
    load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96)
    nvvm.cp_async_commit_group()

    store_h_smem_to_gmem(h_sh_chunk0, h_out, tidx, 0)

    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk1[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk1[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    pred_sh[lane_idx, warp_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[lane_idx, 0]
        + pred_sh[lane_idx, 1]
        + pred_sh[lane_idx, 2]
        + pred_sh[lane_idx, 3]
    )

    v_val = (v_sh[32 + lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[lane_idx, warp_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[lane_idx, 0]
        + out_sh[lane_idx, 1]
        + out_sh[lane_idx, 2]
        + out_sh[lane_idx, 3]
    )

    write_h_chunk_to_smem(h_chunk, h_sh_chunk1, lane_idx, k_base)
    if warp_idx == 0:
        o_head[32 + lane_idx] = out_final.to(cutlass.BFloat16)

    # ========================================================================
    # CHUNK 2 (V rows 64..95) — same pipeline; flushes chunk 1's H.
    # ========================================================================
    nvvm.cp_async_wait_group(1)
    cute.arch.sync_threads()

    store_h_smem_to_gmem(h_sh_chunk1, h_out, tidx, 32)

    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk2[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk2[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    pred_sh[lane_idx, warp_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[lane_idx, 0]
        + pred_sh[lane_idx, 1]
        + pred_sh[lane_idx, 2]
        + pred_sh[lane_idx, 3]
    )

    v_val = (v_sh[64 + lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[lane_idx, warp_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[lane_idx, 0]
        + out_sh[lane_idx, 1]
        + out_sh[lane_idx, 2]
        + out_sh[lane_idx, 3]
    )

    write_h_chunk_to_smem(h_chunk, h_sh_chunk2, lane_idx, k_base)
    if warp_idx == 0:
        o_head[64 + lane_idx] = out_final.to(cutlass.BFloat16)

    # ========================================================================
    # CHUNK 3 (V rows 96..127) — same pipeline; flushes chunk 2's H.
    # ========================================================================
    nvvm.cp_async_wait_group(0)
    cute.arch.sync_threads()

    store_h_smem_to_gmem(h_sh_chunk2, h_out, tidx, 64)

    pred = cutlass.Float32(0.0)
    pred2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(
                h_sh_chunk3[lane_idx, k_base + i].to(cutlass.Float32),
                h_sh_chunk3[lane_idx, k_base + i + 1].to(cutlass.Float32),
            ),
            src_b=(g_exp, g_exp),
            src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)),
        )
    for i in cutlass.range_constexpr(0, 32, 2):
        pred, pred2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(k_chunk[i], k_chunk[i + 1]),
            src_c=(pred, pred2),
        )
    pred = pred + pred2

    pred_sh[lane_idx, warp_idx] = pred
    cute.arch.sync_threads()
    pred_final = (
        pred_sh[lane_idx, 0]
        + pred_sh[lane_idx, 1]
        + pred_sh[lane_idx, 2]
        + pred_sh[lane_idx, 3]
    )

    v_val = (v_sh[96 + lane_idx] - pred_final) * beta

    for i in cutlass.range_constexpr(0, 32, 2):
        h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2(
            src_a=(k_chunk[i], k_chunk[i + 1]),
            src_b=(v_val, v_val),
            src_c=(h_chunk[i], h_chunk[i + 1]),
        )

    for i in cutlass.range_constexpr(32):
        qk_temp[i] = q_sh[k_base + i]

    out = cutlass.Float32(0.0)
    out2 = cutlass.Float32(0.0)
    for i in cutlass.range_constexpr(0, 32, 2):
        out, out2 = cute.arch.fma_packed_f32x2(
            src_a=(h_chunk[i], h_chunk[i + 1]),
            src_b=(qk_temp[i], qk_temp[i + 1]),
            src_c=(out, out2),
        )
    out = out + out2

    out_sh[lane_idx, warp_idx] = out
    cute.arch.sync_threads()
    out_final = (
        out_sh[lane_idx, 0]
        + out_sh[lane_idx, 1]
        + out_sh[lane_idx, 2]
        + out_sh[lane_idx, 3]
    )

    write_h_chunk_to_smem(h_chunk, h_sh_chunk3, lane_idx, k_base)
    if warp_idx == 0:
        o_head[96 + lane_idx] = out_final.to(cutlass.BFloat16)

    # Final flush of the last H chunk after all warps finish writing it to SMEM.
    cute.arch.sync_threads()
    store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96)


# ==============================================================================
# UNIFIED SEQLEN=2/3/4 MAIN KERNEL
# ==============================================================================


@cute.kernel
def gated_delta_rule_decode_kernel_seqlen234_unified(
    gQ: cute.Tensor,  # [B, T=2/3/4, H, K=128]
    gK: cute.Tensor,  # [B, T=2/3/4, H, K=128]
    gV: cute.Tensor,  # [B, T=2/3/4, HV, V=128]
    ga: cute.Tensor,  # [B, T=2/3/4, HV]
    gb: cute.Tensor,  # [B, T=2/3/4, HV]
    gA_log: cute.Tensor,  # [HV]
    gdt_bias: cute.Tensor,  # [HV]
    gH: cute.Tensor,  # [B, HV, V=128, K=128] - K-fast layout
    gO: cute.Tensor,  # [B, T=2/3/4, HV, V=128]
    scale: cutlass.Float32,
    softplus_beta: cutlass.Float32,
    softplus_threshold: cutlass.Float32,
    eps: cutlass.Float32,
    NUM_TOKENS:
    cutlass.Constexpr[int],  # 2, 3, or 4
):
    """
    Unified kernel for Seqlen=2, Seqlen=3 and Seqlen=4 with compile-time specialization.

    Uses cutlass.Constexpr[int] NUM_TOKENS parameter to eliminate dead code paths:
    - NUM_TOKENS=2: 4-slot reduce_sh, 2 Q/K/V buffers, 2 gates
    - NUM_TOKENS=3: 6-slot reduce_sh, 3 Q/K/V buffers, 3 gates
    - NUM_TOKENS=4: 8-slot reduce_sh, 4 Q/K/V buffers, 4 gates
    """
    # One CTA per (batch, value head); 128 threads = 4 warps.
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    HV = cutlass.Int32(gV.shape[2])
    H = cutlass.Int32(gQ.shape[2])

    batch_idx = bidx // HV
    value_head_idx = bidx % HV
    # GQA-style mapping: HV // H value heads share one query head.
    query_head_idx = value_head_idx // (HV // H)

    warp_idx = tidx // 32
    lane_idx = tidx % 32
    k_base = warp_idx * 32  # each warp owns a 32-wide K-slice

    smem = utils.SmemAllocator()

    # SMEM Allocation - H chunks (four 32x128 bf16 tiles for the cp.async pipeline)
    h_sh_chunk0 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk1 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk2 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )
    h_sh_chunk3 = smem.allocate_tensor(
        cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1))
    )

    # Q/K buffers for tokens 0 and 1 (always needed for T>=2)
    q_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
    q_sh1 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh1 = smem.allocate_tensor(cutlass.Float32, 128)

    # Q/K buffers for token 2 (allocated unconditionally; used only when NUM_TOKENS >= 3)
    q_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh2 = smem.allocate_tensor(cutlass.Float32, 128)

    # Q/K buffers for token 3 (allocated unconditionally; used only when NUM_TOKENS=4)
    q_sh3 = smem.allocate_tensor(cutlass.Float32, 128)
    k_sh3 = smem.allocate_tensor(cutlass.Float32, 128)

    # V buffers (one per token; 2/3 left unused for smaller NUM_TOKENS)
    v_sh0 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh1 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh2 = smem.allocate_tensor(cutlass.Float32, 128)
    v_sh3 = smem.allocate_tensor(cutlass.Float32, 128)

    # Bank-conflict-free reduce_sh: [slot, lane_idx, warp_idx]
    # 8 slots are always allocated; smaller NUM_TOKENS uses only a prefix.
    reduce_sh = smem.allocate_tensor(
        cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128, 4, 1))
    )

    # Register allocation
    h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32)
    kq_chunk = cute.make_rmem_tensor((32,), cutlass.Float32)

    # Gate computation - always compute gates 0, 1 (for T>=2)
    A_log_val = gA_log[value_head_idx]
    dt_bias_val = gdt_bias[value_head_idx]

    alpha0 = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32)
    beta_raw0 = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32)
    g_exp0, beta0 = compute_single_gate(
        alpha0, beta_raw0, dt_bias_val, A_log_val, softplus_beta, softplus_threshold
    )

    alpha1 = ga[(batch_idx, 1, value_head_idx)].to(cutlass.Float32)
    beta_raw1 = gb[(batch_idx, 1, value_head_idx)].to(cutlass.Float32)
    g_exp1, beta1 = compute_single_gate(
        alpha1, beta_raw1, dt_bias_val, A_log_val, softplus_beta, softplus_threshold
    )

    # Gate 2 - only for NUM_TOKENS >= 3 (zero-initialized so the names are
    # always bound; dead for smaller NUM_TOKENS thanks to Constexpr dispatch)
    g_exp2 = cutlass.Float32(0.0)
    beta2 = cutlass.Float32(0.0)
    if NUM_TOKENS >= 3:
        alpha2 = ga[(batch_idx, 2, value_head_idx)].to(cutlass.Float32)
        beta_raw2 = gb[(batch_idx, 2, value_head_idx)].to(cutlass.Float32)
        g_exp2, beta2 = compute_single_gate(
            alpha2, beta_raw2, dt_bias_val, A_log_val, softplus_beta, softplus_threshold
        )

    # Gate 3 - only for NUM_TOKENS = 4
    g_exp3 = cutlass.Float32(0.0)
    beta3 = cutlass.Float32(0.0)
    if NUM_TOKENS == 4:
        alpha3 = ga[(batch_idx, 3, value_head_idx)].to(cutlass.Float32)
        beta_raw3 = gb[(batch_idx, 3, value_head_idx)].to(cutlass.Float32)
        g_exp3, beta3 = compute_single_gate(
            alpha3, beta_raw3, dt_bias_val, A_log_val, softplus_beta, softplus_threshold
        )

    # Upfront H loading: issue all four async copies, one commit group each,
    # so cp_async_wait_group(N) can gate each V-chunk individually below.
    h_global = gH[(batch_idx, value_head_idx, None, None)]
    load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0)
    nvvm.cp_async_commit_group()
    load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32)
    nvvm.cp_async_commit_group()
    load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64)
    nvvm.cp_async_commit_group()
    load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96)
    nvvm.cp_async_commit_group()

    # Q/K normalization - tokens 0, 1 always (one warp per token, in parallel)
    q_head0 = gQ[(batch_idx, 0, query_head_idx, None)]
    k_head0 = gK[(batch_idx, 0, query_head_idx, None)]
    q_head1 = gQ[(batch_idx, 1, query_head_idx, None)]
    k_head1 = gK[(batch_idx, 1, query_head_idx, None)]

    if warp_idx == 0:
        normalize_and_store_qk_to_smem(
            q_head0, k_head0, q_sh0, k_sh0, lane_idx, scale, eps
        )
    if warp_idx == 1:
        normalize_and_store_qk_to_smem(
            q_head1, k_head1, q_sh1, k_sh1, lane_idx, scale, eps
        )

    # Token 2 Q/K normalization - only for NUM_TOKENS >= 3
    if NUM_TOKENS >= 3:
        q_head2 = gQ[(batch_idx, 2, query_head_idx, None)]
        k_head2 = gK[(batch_idx, 2, query_head_idx, None)]
        if warp_idx == 2:
            normalize_and_store_qk_to_smem(
                q_head2, k_head2, q_sh2, k_sh2, lane_idx, scale, eps
            )

    # Token 3 Q/K normalization - only for NUM_TOKENS = 4
    if NUM_TOKENS == 4:
        q_head3 = gQ[(batch_idx, 3, query_head_idx, None)]
        k_head3 = gK[(batch_idx, 3, query_head_idx, None)]
        if warp_idx == 3:
            normalize_and_store_qk_to_smem(
                q_head3, k_head3, q_sh3, k_sh3, lane_idx, scale, eps
            )

    cute.arch.sync_threads()

    # V loading - tokens 0, 1 always
    v_head0 = gV[(batch_idx, 0, value_head_idx, None)]
    v_head1 = gV[(batch_idx, 1, value_head_idx, None)]
    load_v_to_smem(v_head0, v_sh0, tidx)
    load_v_to_smem(v_head1, v_sh1, tidx)

    # Token 2 V loading - only for NUM_TOKENS >= 3
    if NUM_TOKENS >= 3:
        v_head2 = gV[(batch_idx, 2, value_head_idx, None)]
        load_v_to_smem(v_head2, v_sh2, tidx)

    # Token 3 V loading - only for NUM_TOKENS = 4
    if NUM_TOKENS == 4:
        v_head3 = gV[(batch_idx, 3, value_head_idx, None)]
        load_v_to_smem(v_head3, v_sh3, tidx)

    # Output pointers - tokens 0, 1 always
    h_out = gH[(batch_idx, value_head_idx, None, None)]
    o_head0 = gO[(batch_idx, 0, value_head_idx, None)]
    o_head1 = gO[(batch_idx, 1, value_head_idx, None)]

    # Token 2 output pointer (aliased so the name is always bound; never
    # written for T=2 because the T=2 code path stops at token 1)
    o_head2 = o_head1  # Default for T=2
    if NUM_TOKENS >= 3:
        o_head2 = gO[(batch_idx, 2, value_head_idx, None)]

    # Token 3 output pointer
    o_head3 = o_head2  # Default for T=2,3
    if NUM_TOKENS == 4:
        o_head3 = gO[(batch_idx, 3, value_head_idx, None)]

    # Process V-CHUNK 0 (wait_group(3): allow 3 groups in flight -> chunk 0 landed)
    nvvm.cp_async_wait_group(3)
    cute.arch.sync_threads()
    process_vchunk_unified_234(
        h_sh_chunk0,
        h_sh_chunk0,
        h_out,
        h_chunk,
        kq_chunk,
        k_sh0,
        k_sh1,
        k_sh2,
        k_sh3,
        q_sh0,
        q_sh1,
        q_sh2,
        q_sh3,
        v_sh0,
        v_sh1,
        v_sh2,
        v_sh3,
        reduce_sh,
        o_head0,
        o_head1,
        o_head2,
        o_head3,
        g_exp0,
        g_exp1,
        g_exp2,
        g_exp3,
        beta0,
        beta1,
        beta2,
        beta3,
        0,  # v_offset
        0,  # prev_v_offset
        cutlass.Int32(0),  # store_prev: first chunk has nothing to flush
        tidx,
        warp_idx,
        lane_idx,
        k_base,
        NUM_TOKENS,
    )

    # Process V-CHUNK 1 (also flushes chunk 0's updated H inside the call)
    nvvm.cp_async_wait_group(2)
    cute.arch.sync_threads()
    process_vchunk_unified_234(
        h_sh_chunk1,
        h_sh_chunk0,
        h_out,
        h_chunk,
        kq_chunk,
        k_sh0,
        k_sh1,
        k_sh2,
        k_sh3,
        q_sh0,
        q_sh1,
        q_sh2,
        q_sh3,
        v_sh0,
        v_sh1,
        v_sh2,
        v_sh3,
        reduce_sh,
        o_head0,
        o_head1,
        o_head2,
        o_head3,
        g_exp0,
        g_exp1,
        g_exp2,
        g_exp3,
        beta0,
        beta1,
        beta2,
        beta3,
        32,  # v_offset
        0,  # prev_v_offset
        cutlass.Int32(1),  # store_prev
        tidx,
        warp_idx,
        lane_idx,
        k_base,
        NUM_TOKENS,
    )

    # Process V-CHUNK 2
    nvvm.cp_async_wait_group(1)
    cute.arch.sync_threads()
    process_vchunk_unified_234(
        h_sh_chunk2,
        h_sh_chunk1,
        h_out,
        h_chunk,
        kq_chunk,
        k_sh0,
        k_sh1,
        k_sh2,
        k_sh3,
        q_sh0,
        q_sh1,
        q_sh2,
        q_sh3,
        v_sh0,
        v_sh1,
        v_sh2,
        v_sh3,
        reduce_sh,
        o_head0,
        o_head1,
        o_head2,
        o_head3,
        g_exp0,
        g_exp1,
        g_exp2,
        g_exp3,
        beta0,
        beta1,
        beta2,
        beta3,
        64,  # v_offset
        32,  # prev_v_offset
        cutlass.Int32(1),  # store_prev
        tidx,
        warp_idx,
        lane_idx,
        k_base,
        NUM_TOKENS,
    )

    # Process V-CHUNK 3
    nvvm.cp_async_wait_group(0)
    cute.arch.sync_threads()
    process_vchunk_unified_234(
        h_sh_chunk3,
        h_sh_chunk2,
        h_out,
        h_chunk,
        kq_chunk,
        k_sh0,
        k_sh1,
        k_sh2,
        k_sh3,
        q_sh0,
        q_sh1,
        q_sh2,
        q_sh3,
        v_sh0,
        v_sh1,
        v_sh2,
        v_sh3,
        reduce_sh,
        o_head0,
        o_head1,
        o_head2,
        o_head3,
        g_exp0,
        g_exp1,
        g_exp2,
        g_exp3,
        beta0,
        beta1,
        beta2,
        beta3,
        96,  # v_offset
        64,  # prev_v_offset
        cutlass.Int32(1),  # store_prev
        tidx,
        warp_idx,
        lane_idx,
        k_base,
        NUM_TOKENS,
    )

    # Final H store (chunk 3) after all warps finish staging it to SMEM.
    cute.arch.sync_threads()
    store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96)


# ==============================================================================
# LAUNCH WRAPPERS
# ==============================================================================


@cute.jit
def gated_delta_rule_launch_seqlen1(
    mQ: cute.Tensor,
    mK: cute.Tensor,
    mV: cute.Tensor,
    ma: cute.Tensor,
    mb: cute.Tensor,
    mA_log: cute.Tensor,
    mdt_bias: cute.Tensor,
    mH: cute.Tensor,
    mO: cute.Tensor,
    scale: cutlass.Float32,
    softplus_beta: cutlass.Float32,
    softplus_threshold: cutlass.Float32,
    eps: cutlass.Float32,
    stream: cuda.CUstream,
):
    """Launch the seqlen=1 decode kernel: one CTA per (batch, value head),
    128 threads (4 warps) per CTA."""
    batch_size = mQ.shape[0]
    HV = mV.shape[2]

    gated_delta_rule_decode_kernel_seqlen1(
        mQ,
        mK,
        mV,
        ma,
        mb,
        mA_log,
        mdt_bias,
        mH,
        mO,
        scale,
        softplus_beta,
        softplus_threshold,
        eps,
    ).launch(
        grid=[batch_size * HV, 1, 1],
        block=[128, 1, 1],
        stream=stream,
    )


# ==============================================================================
# LOW-BS SEQLEN=1 KERNEL - 1 V-CHUNK PER CTA (T=1, BS<=4)
# ==============================================================================


@cute.kernel
def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk(
    gQ: cute.Tensor,
    gK: cute.Tensor,
    gV: cute.Tensor,
    ga: cute.Tensor,
    gb: cute.Tensor,
    gA_log: cute.Tensor,
    gdt_bias: cute.Tensor,
    gH: cute.Tensor,
    gO: cute.Tensor,
    scale: cutlass.Float32,
    softplus_beta: cutlass.Float32,
    softplus_threshold:
cutlass.Float32, + eps: cutlass.Float32, +): + """ + Seqlen=1 kernel with 1 V-chunk (32 V rows) per CTA. + For T=1, batch_size <= 4: more CTAs per batch*head for better SM utilization. + Grid: batch_idx * HV * 4 + value_head_idx * 4 + v_chunk_idx (0..3). + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + HV = cutlass.Int32(gV.shape[2]) + H = cutlass.Int32(gQ.shape[2]) + + batch_idx = bidx // (HV * 4) + remainder = bidx % (HV * 4) + value_head_idx = remainder // 4 + v_chunk_idx = remainder % 4 + + query_head_idx = value_head_idx // (HV // H) + v_row_base = v_chunk_idx * 32 + + smem = utils.SmemAllocator() + + alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) + A_log_val = gA_log[value_head_idx] + dt_bias_val = gdt_bias[value_head_idx] + g_exp, beta = compute_single_gate( + alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + ) + + h_sh_chunk = smem.allocate_tensor( + cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + ) + + q_sh = smem.allocate_tensor(cutlass.Float32, 128) + k_sh = smem.allocate_tensor(cutlass.Float32, 128) + + pred_sh = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + ) + out_sh = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + ) + + h_global = gH[(batch_idx, value_head_idx, None, None)] + + load_h_chunk_async(h_sh_chunk, h_global, tidx, v_row_base) + nvvm.cp_async_commit_group() + + q_head = gQ[(batch_idx, 0, query_head_idx, None)] + k_head = gK[(batch_idx, 0, query_head_idx, None)] + + warp_idx = tidx // 32 + lane_idx = tidx % 32 + + if warp_idx == 0: + normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps) + + cute.arch.sync_threads() + + v_head = gV[(batch_idx, 0, value_head_idx, None)] + v_sh = smem.allocate_tensor(cutlass.Float32, 32) + if tidx < 32: + v_sh[tidx] = v_head[v_row_base + 
tidx].to(cutlass.Float32) + + h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) + k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) + qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32) + + k_base = warp_idx * 32 + + for i in cutlass.range_constexpr(32): + k_chunk[i] = k_sh[k_base + i] + + h_out = gH[(batch_idx, value_head_idx, None, None)] + o_head = gO[(batch_idx, 0, value_head_idx, None)] + + nvvm.cp_async_wait_group(0) + cute.arch.sync_threads() + + pred = cutlass.Float32(0.0) + pred2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32), + h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32), + ), + src_b=(g_exp, g_exp), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + for i in cutlass.range_constexpr(0, 32, 2): + pred, pred2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(k_chunk[i], k_chunk[i + 1]), + src_c=(pred, pred2), + ) + pred = pred + pred2 + + pred_sh[lane_idx, warp_idx] = pred + cute.arch.sync_threads() + pred_final = ( + pred_sh[lane_idx, 0] + + pred_sh[lane_idx, 1] + + pred_sh[lane_idx, 2] + + pred_sh[lane_idx, 3] + ) + + v_val = (v_sh[lane_idx] - pred_final) * beta + + for i in cutlass.range_constexpr(0, 32, 2): + h_chunk[i], h_chunk[i + 1] = cute.arch.fma_packed_f32x2( + src_a=(k_chunk[i], k_chunk[i + 1]), + src_b=(v_val, v_val), + src_c=(h_chunk[i], h_chunk[i + 1]), + ) + + for i in cutlass.range_constexpr(32): + qk_temp[i] = q_sh[k_base + i] + + out = cutlass.Float32(0.0) + out2 = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(0, 32, 2): + out, out2 = cute.arch.fma_packed_f32x2( + src_a=(h_chunk[i], h_chunk[i + 1]), + src_b=(qk_temp[i], qk_temp[i + 1]), + src_c=(out, out2), + ) + out = out + out2 + + out_sh[lane_idx, warp_idx] = out + cute.arch.sync_threads() + out_final = ( + out_sh[lane_idx, 0] + + out_sh[lane_idx, 1] + + 
out_sh[lane_idx, 2] + + out_sh[lane_idx, 3] + ) + + write_h_chunk_to_smem(h_chunk, h_sh_chunk, lane_idx, k_base) + if warp_idx == 0: + o_head[v_row_base + lane_idx] = out_final.to(cutlass.BFloat16) + + cute.arch.sync_threads() + store_h_smem_to_gmem(h_sh_chunk, h_out, tidx, v_row_base) + + +@cute.jit +def gated_delta_rule_launch_seqlen1_lowBS_1chunk( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + """Launch LowBS-1 kernel: 4 CTAs per (batch, value_head).""" + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + ).launch( + grid=[batch_size * HV * 4, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +@cute.jit +def gated_delta_rule_launch_seqlen2( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen234_unified( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + 2, # NUM_TOKENS=2 + ).launch( + grid=[batch_size * HV, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +@cute.jit +def gated_delta_rule_launch_seqlen3( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: 
cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen234_unified( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + 3, # NUM_TOKENS=3 + ).launch( + grid=[batch_size * HV, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +@cute.jit +def gated_delta_rule_launch_seqlen4( + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + ma: cute.Tensor, + mb: cute.Tensor, + mA_log: cute.Tensor, + mdt_bias: cute.Tensor, + mH: cute.Tensor, + mO: cute.Tensor, + scale: cutlass.Float32, + softplus_beta: cutlass.Float32, + softplus_threshold: cutlass.Float32, + eps: cutlass.Float32, + stream: cuda.CUstream, +): + batch_size = mQ.shape[0] + HV = mV.shape[2] + + gated_delta_rule_decode_kernel_seqlen234_unified( + mQ, + mK, + mV, + ma, + mb, + mA_log, + mdt_bias, + mH, + mO, + scale, + softplus_beta, + softplus_threshold, + eps, + 4, # NUM_TOKENS=4 + ).launch( + grid=[batch_size * HV, 1, 1], + block=[128, 1, 1], + stream=stream, + ) + + +# ============================================================================== +# KERNEL CLASS +# ============================================================================== + + +class GatedDeltaRuleKernel: + """ + Gated Delta Rule Kernel for linear attention decode. + + This kernel implements the Gated Delta Rule mechanism supporting sequence + lengths T=1, T=2, T=3, T=4 with optimized CUDA implementations. 
+ + Key features: + - T=1: Persistent K in registers with aggressive pipelining + - T=2/3/4: Unified kernel with compile-time Constexpr specialization + - L2-normalized Q/K with configurable scale + - Gated exponential decay via softplus + - Bank-conflict-free cross-warp reductions + - Async H memory loading + + Args: + seq_len: Sequence length (1, 2, 3, or 4) + """ + + def __init__(self, seq_len: int): + assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}" + self.seq_len = seq_len + self._compiled_kernel = None + + def _get_launch_fn(self): + if self.seq_len == 1: + return gated_delta_rule_launch_seqlen1 + elif self.seq_len == 2: + return gated_delta_rule_launch_seqlen2 + elif self.seq_len == 3: + return gated_delta_rule_launch_seqlen3 + else: + return gated_delta_rule_launch_seqlen4 + + +# ============================================================================== +# PUBLIC API +# ============================================================================== + +_compiled_kernels = {} # Cache: (seqlen, batch_size) -> compiled kernel + + +def gated_delta_rule( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + q: Optional[torch.Tensor] = None, + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + b: Optional[torch.Tensor] = None, + initial_state_source: Optional[torch.Tensor] = None, + initial_state_indices: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = True, + scale: Optional[float] = None, +) -> torch.Tensor: + """ + Gated Delta Rule linear attention kernel. + + Implements the Gated Delta Rule mechanism for decode-phase inference, + supporting sequence lengths T=1, T=2, T=3, T=4. 
+ + Args: + A_log: Log decay parameter [HV] + a: Alpha gate input [B, T, HV] + dt_bias: Delta-t bias [HV] + softplus_beta: Softplus beta parameter (default: 1.0) + softplus_threshold: Softplus threshold (default: 20.0) + q: Query tensor [B, T, H, K] + k: Key tensor [B, T, H, K] + v: Value tensor [B, T, HV, V] + b: Beta gate input [B, T, HV] + initial_state_source: H state [B, HV, V, K] (K-fast layout), modified in-place + initial_state_indices: Not used (for compatibility) + use_qk_l2norm_in_kernel: Whether to L2-normalize Q/K in kernel (default: True) + scale: Optional attention scale (default: 1/sqrt(K)) + + Returns: + output: [B, T, HV, V] + + Example: + >>> B, T, H, K = 16, 1, 16, 128 + >>> HV, V = 32, 128 + >>> q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) + >>> k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) + >>> v = torch.randn(B, T, HV, V, device='cuda', dtype=torch.bfloat16) + >>> a = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) + >>> b = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) + >>> A_log = torch.randn(HV, device='cuda', dtype=torch.float32) + >>> dt_bias = torch.randn(HV, device='cuda', dtype=torch.float32) + >>> h_state = torch.randn(B, HV, V, K, device='cuda', dtype=torch.bfloat16) + >>> output = gated_delta_rule( + ... A_log, a, dt_bias, q=q, k=k, v=v, b=b, + ... initial_state_source=h_state + ... 
) + """ + global _compiled_kernels + + # Validate required Optional parameters + if q is None: + raise ValueError("q (query tensor) is required") + if k is None: + raise ValueError("k (key tensor) is required") + if v is None: + raise ValueError("v (value tensor) is required") + if b is None: + raise ValueError("b (beta gate tensor) is required") + if initial_state_source is None: + raise ValueError("initial_state_source (H state tensor) is required") + + B, T, H, K = q.shape + assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}" + HV = v.shape[2] + V = v.shape[3] + + if scale is None: + scale = 1.0 / math.sqrt(K) + + output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) + + q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True) + k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True) + v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True) + a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True) + b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True) + A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True) + dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) + h_ = from_dlpack(initial_state_source, assumed_align=32, enable_tvm_ffi=True) + o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) + + scale_f32 = cutlass.Float32(scale) + softplus_beta_f32 = cutlass.Float32(softplus_beta) + softplus_threshold_f32 = cutlass.Float32(softplus_threshold) + eps_f32 = cutlass.Float32(1e-6) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Check cache - include all shape dimensions to avoid incorrect reuse + cache_key = (T, B, H, HV, K, V) + if cache_key not in _compiled_kernels: + # Select and compile the appropriate kernel + if T == 1 and B <= 4: + launch_fn = gated_delta_rule_launch_seqlen1_lowBS_1chunk + elif T == 1: + launch_fn = gated_delta_rule_launch_seqlen1 + elif T == 2: + launch_fn = gated_delta_rule_launch_seqlen2 + elif T == 3: + launch_fn = 
gated_delta_rule_launch_seqlen3 + else: # T == 4 + launch_fn = gated_delta_rule_launch_seqlen4 + + _compiled_kernels[cache_key] = cute.compile( + launch_fn, + q_, + k_, + v_, + a_, + b_, + A_log_, + dt_bias_, + h_, + o_, + scale_f32, + softplus_beta_f32, + softplus_threshold_f32, + eps_f32, + stream, + options="--enable-tvm-ffi --generate-line-info", + ) + + # Execute + _compiled_kernels[cache_key]( + q_, + k_, + v_, + a_, + b_, + A_log_, + dt_bias_, + h_, + o_, + scale_f32, + softplus_beta_f32, + softplus_threshold_f32, + eps_f32, + stream, + ) + + return output diff --git a/tests/gdn/reference_delta_rule.py b/tests/gdn/reference_delta_rule.py index 3fa10e0a2d..7296610bbd 100644 --- a/tests/gdn/reference_delta_rule.py +++ b/tests/gdn/reference_delta_rule.py @@ -136,7 +136,7 @@ def blockwise_linear_attention( decay_factor: float | torch.Tensor = 1.0, # float or tensor with num_elems == num_qo_heads decay_exponent_offset=0, - kv_dtype: torch.dtype = torch.float32, + state_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: num_qo_heads = q.size(1) head_size = q.size(2) @@ -156,7 +156,7 @@ def blockwise_linear_attention( KVs = [] # FIXME: kernel debug only kv = torch.zeros( (len(seq_lens), num_qo_heads, head_size, head_size), - dtype=kv_dtype, + dtype=state_dtype, device=q.device, ) output = torch.zeros_like(q) @@ -166,7 +166,7 @@ def blockwise_linear_attention( seq_end = seq_offset[seq_idx + 1] blk_offset = seq_start carried_kv = torch.zeros( - (num_qo_heads, head_size, head_size), dtype=kv_dtype, device=q.device + (num_qo_heads, head_size, head_size), dtype=state_dtype, device=q.device ) while blk_offset < seq_end: is_full_block = seq_end - blk_offset >= block_size @@ -205,7 +205,10 @@ def blockwise_linear_attention( ) o_inter = ( - matmul(q_t.transpose(0, 1).to(kv_dtype) * Lq, carried_kv) + matmul( + q_t.transpose(0, 1).to(torch.float32) * Lq, + carried_kv.to(torch.float32), + ) .transpose(0, 1) .to(q.dtype) ) @@ -219,10 +222,10 @@ def 
blockwise_linear_attention( if (decay_factor == 1.0).all(): inc_kv = matmul( - k_t.transpose(0, 1).transpose(-2, -1).to(kv_dtype), - v_t.transpose(0, 1).to(kv_dtype), + k_t.transpose(0, 1).transpose(-2, -1).to(torch.float32), + v_t.transpose(0, 1).to(torch.float32), ) - carried_kv = carried_kv + inc_kv + carried_kv = (carried_kv.to(torch.float32) + inc_kv).to(state_dtype) else: Lk = LambdaK( decay_factor, @@ -232,11 +235,13 @@ def blockwise_linear_attention( offset=decay_exponent_offset, ) inc_kv = matmul( - (k_t.transpose(0, 1) * Lk).transpose(-2, -1).to(kv_dtype), - v_t.transpose(0, 1).to(kv_dtype), + (k_t.transpose(0, 1) * Lk).transpose(-2, -1).to(torch.float32), + v_t.transpose(0, 1).to(torch.float32), ) block_decay = decay_factor**valid_len - carried_kv = block_decay * carried_kv + inc_kv + carried_kv = (block_decay * carried_kv.to(torch.float32) + inc_kv).to( + state_dtype + ) KVs.append(carried_kv.clone()) blk_offset += block_size @@ -256,7 +261,7 @@ def delta_rule( alpha: torch.Tensor | None = None, # [total_seq_len, num_qo_heads] beta: torch.Tensor | None = None, # [total_seq_len, num_qo_heads] scale_factor=1.0, - kv_dtype: torch.dtype = torch.float32, + state_dtype: torch.dtype = torch.float32, ): o = [] kv = [] @@ -297,7 +302,7 @@ def delta_rule( betas = beta[s] state_HKV = torch.zeros( - num_q_heads, head_size, head_size, dtype=kv_dtype, device=q.device + num_q_heads, head_size, head_size, dtype=state_dtype, device=q.device ) for i in range(seq_len): # var_DS where var is variable basename and DS is the dimensional semantics. 
@@ -311,14 +316,15 @@ def delta_rule( ### listed at the bottom of page3 of section 2.2 DELTA NETWORKS: LINEAR ATTENTION WITH DELTA RULE # state update rule, use the middle version for clearer dimensional semantics - old_state_HKV = alpha_H11 * state_HKV + # Read state in fp32, compute in fp32, store back in state_dtype + old_state_HKV = alpha_H11 * state_HKV.to(torch.float32) old_v_H1V = matmul(k_H1K, old_state_HKV) new_v_H1V = beta_H11 * v_H1V + (1 - beta_H11) * old_v_H1V state_remove = torch.einsum("htv,htk->hkv", old_v_H1V, k_H1K) state_update = torch.einsum("htv,htk->hkv", new_v_H1V, k_H1K) - state_HKV[:] = old_state_HKV - state_remove + state_update + state_HKV[:] = (old_state_HKV - state_remove + state_update).to(state_dtype) - o_H1V = scale_factor * matmul(q_H1Q, state_HKV) + o_H1V = scale_factor * matmul(q_H1Q, state_HKV.to(torch.float32)) o.append(o_H1V.squeeze(1)) kv.append(state_HKV.clone()) @@ -356,7 +362,7 @@ def blockwise_delta_rule( beta: torch.Tensor | None = None, # [total_seq_len, num_qo_heads] block_size: int = 32, scale_factor=1.0, - kv_dtype: torch.dtype = torch.float32, + state_dtype: torch.dtype = torch.float32, # intermediate_outputs = None, # debug output ) -> torch.Tensor: total_seqlen = q.size(0) @@ -386,7 +392,7 @@ def blockwise_delta_rule( kv = torch.zeros( (len(seq_lens), num_sab_heads, head_size, head_size), - dtype=kv_dtype, + dtype=state_dtype, device=q.device, ) output = torch.zeros_like(q) @@ -396,7 +402,7 @@ def blockwise_delta_rule( seq_end = seq_offset[seq_idx + 1] blk_offset = seq_start state_HKV = torch.zeros( - (num_sab_heads, head_size, head_size), dtype=kv_dtype, device=q.device + (num_sab_heads, head_size, head_size), dtype=state_dtype, device=q.device ) while blk_offset < seq_end: is_full_block = seq_end - blk_offset >= block_size @@ -455,7 +461,9 @@ def blockwise_delta_rule( # new_v_HSV = matmul(T, (v_HSV - matmul(torch.exp(gamma_HS1) * k_HSK, state_HKV))) u_HSV = matmul(T, v_HSV) w_HSK = matmul(T, torch.exp(gamma_HS1) 
* k_HSK) - new_v_HSV = u_HSV - matmul(w_HSK.to(kv_dtype), state_HKV).to(u_HSV.dtype) + new_v_HSV = u_HSV - matmul( + w_HSK.to(torch.float32), state_HKV.to(torch.float32) + ).to(u_HSV.dtype) new_v_SHV = new_v_HSV.transpose(0, 1) # if intermediate_outputs is not None: @@ -468,7 +476,10 @@ def blockwise_delta_rule( # intermediate_outputs["new_v"].append(new_v_HSV.clone()) o_inter = ( - matmul(torch.exp(gamma_HS1) * q_HSQ.to(kv_dtype), state_HKV) + matmul( + torch.exp(gamma_HS1) * q_HSQ.to(torch.float32), + state_HKV.to(torch.float32), + ) .transpose(0, 1) .to(q.dtype) ) @@ -484,10 +495,12 @@ def blockwise_delta_rule( inc_HKV = matmul( (torch.exp(block_gamma - gamma_HS1) * k_HSK) .transpose(-2, -1) - .to(kv_dtype), - new_v_HSV.to(kv_dtype), + .to(torch.float32), + new_v_HSV.to(torch.float32), ) - state_HKV = torch.exp(block_gamma) * state_HKV + inc_HKV + state_HKV = ( + torch.exp(block_gamma) * state_HKV.to(torch.float32) + inc_HKV + ).to(state_dtype) blk_offset += block_size @@ -510,6 +523,7 @@ def decode_delta_rule( softplus_beta: float = 1.0, softplus_threshold: float = 20.0, use_l2_norm: bool = True, + state_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: """ Reference implementation for single-step decode with GDN formula. 
@@ -537,6 +551,7 @@ def decode_delta_rule( softplus_beta: Beta parameter for softplus activation softplus_threshold: Threshold for softplus numerical stability use_l2_norm: Whether to apply L2 normalization to q and k + state_dtype: Storage dtype for the hidden state (read in fp32, stored in this dtype) Returns: output: [B, num_heads, V] @@ -617,7 +632,7 @@ def decode_delta_rule( # ============================================ # Process each batch and head # ============================================ - new_state = torch.zeros(B, num_heads, K, V, device=device, dtype=dtype) + new_state = torch.zeros(B, num_heads, K, V, device=device, dtype=state_dtype) output = torch.zeros(B, num_heads, V, device=device, dtype=dtype) for b_idx in range(B): @@ -626,7 +641,9 @@ def decode_delta_rule( q_h = q[b_idx, h_idx] # [K] k_h = k[b_idx, h_idx] # [K] v_h = v[b_idx, h_idx] # [V] - h_state = state[b_idx, h_idx].clone() # [K, V] (matches Triton's [BK, BV]) + h_state = ( + state[b_idx, h_idx].clone().to(torch.float32) + ) # [K, V] read as fp32 # Get gating values for this batch and head g_val = g[b_idx, h_idx] # scalar @@ -673,8 +690,8 @@ def decode_delta_rule( # [K] @ [K, V] = [V] output[b_idx, h_idx] = q_h @ h_state - # Store updated state - new_state[b_idx, h_idx] = h_state + # Store updated state (cast back to state_dtype) + new_state[b_idx, h_idx] = h_state.to(state_dtype) return output, new_state @@ -694,6 +711,7 @@ def verify_delta_rule( softplus_threshold: float = 20.0, use_l2_norm: bool = True, cache_intermediate_states: bool = False, + state_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Reference implementation for multi-token (verify mode) delta rule. 
@@ -715,6 +733,7 @@ def verify_delta_rule( softplus_threshold: Threshold for softplus approximation use_l2_norm: Whether to apply L2 normalization cache_intermediate_states: Whether to cache state at each time step + state_dtype: Storage dtype for the hidden state (read in fp32, stored in this dtype) Returns: output: Output tensor [B, T, num_heads, V] @@ -779,11 +798,13 @@ def verify_delta_rule( # Initialize output and intermediate states output = torch.zeros(B, T, num_heads, V, dtype=torch.float32, device=q.device) - current_state = state.clone() # [B, num_heads, K, V] + current_state = state.clone().to( + state_dtype + ) # [B, num_heads, K, V] stored in state_dtype if cache_intermediate_states: intermediate_states = torch.zeros( - B, T, num_heads, K, V, dtype=torch.float32, device=q.device + B, T, num_heads, K, V, dtype=state_dtype, device=q.device ) else: intermediate_states = None @@ -802,7 +823,9 @@ def verify_delta_rule( q_h = q_t[b_idx, h_idx] # [K] k_h = k_t[b_idx, h_idx] # [K] v_h = v_t[b_idx, h_idx] # [V] - h_state = current_state[b_idx, h_idx].clone() # [K, V] + h_state = ( + current_state[b_idx, h_idx].clone().to(torch.float32) + ) # [K, V] read as fp32 g_val = g_t[b_idx, h_idx] beta_val = beta_t[b_idx, h_idx] @@ -825,11 +848,11 @@ def verify_delta_rule( # 5. 
Compute output: o = q^T @ h output[b_idx, t, h_idx] = q_h @ h_state # [K] @ [K, V] = [V] - # Update current state - current_state[b_idx, h_idx] = h_state + # Update current state (cast back to state_dtype) + current_state[b_idx, h_idx] = h_state.to(state_dtype) # Cache intermediate state if requested if cache_intermediate_states: - intermediate_states[b_idx, t, h_idx] = h_state + intermediate_states[b_idx, t, h_idx] = h_state.to(state_dtype) return output, current_state, intermediate_states diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 7d93b9ce98..963198c8a6 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -41,6 +41,16 @@ ) from flashinfer.utils import get_compute_capability +# Import the gdn_decode_klast_bf16_state kernel (T=1..4, bf16 state, K-last layout) +try: + from flashinfer.gdn_kernels.gdn_decode_bf16_state import ( + gated_delta_rule as gdn_decode_klast_bf16_state, + ) + + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True +except ImportError: + GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False + def _skip_if_not_sm90_or_later(): """Skip test if not Hopper (SM90+) or Blackwell (SM100+) architecture.""" @@ -51,6 +61,7 @@ def _skip_if_not_sm90_or_later(): # ============================================================================ # Test decode kernel with pretranspose version ([B*HV, V, K]) +# Reference: fp32 h state (default); bf16 h state used only for gdn_decode_klast_bf16_state. 
# ============================================================================ @@ -149,14 +160,12 @@ def _test_decode_kernel_pretranspose( # Remove T dimension for comparison: [B, 1, H, D] -> [B, H, D] our_o = our_o.squeeze(1) - # Reference implementation (remove T=1 dimension) - # Now passes raw GDN parameters, will compute g and beta internally - # Reference uses [B, HV, K, V] state (matches Triton) + # Reference: fp32 h state (default state_dtype) ref_o, ref_state = decode_delta_rule( q.squeeze(1).float(), # [B, 1, H, K] -> [B, H, K] k.squeeze(1).float(), v.squeeze(1).float(), - input_state_ref, # Use [B, HV, K, V] state for reference + input_state_ref, # [B, HV, K, V] A_log=A_log, a=a.squeeze(1), # Remove T dimension: [B, 1, HV] -> [B, HV] dt_bias=dt_bias, @@ -223,6 +232,7 @@ def test_decode_kernel_basic_pretranspose( # ============================================================================ # Test decode kernel with nontranspose version ([pool, HV, K, V]) +# Reference: fp32 h state (default). # ============================================================================ @@ -315,13 +325,12 @@ def _test_decode_kernel_nontranspose( # Remove T dimension for comparison: [B, 1, H, D] -> [B, H, D] our_o = our_o.squeeze(1) - # Reference implementation (remove T=1 dimension) - # Reference uses [B, HV, K, V] state (matches both Triton and nontranspose kernel) + # Reference: fp32 h state (default state_dtype) ref_o, ref_state = decode_delta_rule( q.squeeze(1).float(), # [B, 1, H, K] -> [B, H, K] k.squeeze(1).float(), v.squeeze(1).float(), - input_state, # Use [B, HV, K, V] state for reference + input_state, # [B, HV, K, V] A_log=A_log, a=a.squeeze(1), # Remove T dimension: [B, 1, HV] -> [B, HV] dt_bias=dt_bias, @@ -388,6 +397,7 @@ def test_decode_kernel_basic_nontranspose( # ============================================================================ # Test verify kernel with MTP version (Multiple Token Processing) +# Reference: fp32 h state (default). 
# ============================================================================ @@ -602,6 +612,313 @@ def test_verify_kernel_mtp( ) +# ============================================================================ +# Test gdn_decode_klast_bf16_state kernel (T=1..4, bf16 state, K-last) +# Reference: bf16 h state only here (state_dtype=torch.bfloat16). Other kernels +# above use fp32 h state reference. +# ============================================================================ + + +def _test_gdn_decode_klast_bf16_state_kernel( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + seq_len: int, # T=1,2,3,4 + scale: float, + alpha: bool, + beta: bool, + seed: int | None = None, +): + """Test gdn_decode_klast_bf16_state kernel for T=1,2,3,4 with bf16 h state. + + Both kernel and reference use bf16 h state: reference runs with + state_dtype=torch.bfloat16 (read h as fp32, compute in fp32, store h in bf16) + so the comparison is apples-to-apples with the gdn_decode_klast_bf16_state kernel. + """ + _skip_if_not_sm90_or_later() + + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + pytest.skip("gdn_decode_klast_bf16_state kernel not available") + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + assert seq_len in [1, 2, 3, 4], ( + f"gdn_decode_klast_bf16_state supports T=1,2,3,4, got T={seq_len}" + ) + + # State and GDN parameters are based on num_v_heads (HV in kernel API) + num_sab_heads = num_v_heads + + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + # Generate inputs with T dimension + q = torch.randn(batch_size, seq_len, num_q_heads, head_size, dtype=dtype_torch) + k = torch.randn(batch_size, seq_len, num_k_heads, head_size, dtype=dtype_torch) + v = torch.randn(batch_size, seq_len, num_v_heads, head_size, dtype=dtype_torch) + + # NOTE: Do NOT pre-normalize K here. 
Both the kernel (use_qk_l2norm_in_kernel=True) + # and reference will apply L2 normalization internally after GQA expansion. + + # gdn_decode_klast_bf16_state kernel expects [B, HV, V, K] (K-fast layout) in BF16. + # Use the same bf16 initial state for both kernel and reference so we + # compare the bf16 h state path. + input_state_kernel = torch.randn( + batch_size, num_sab_heads, head_size, head_size, dtype=torch.bfloat16 + ) + + # Reference uses [B, HV, K, V] layout; same bf16 values as kernel. + input_state_ref_bf16 = input_state_kernel.transpose(-2, -1).contiguous() + + # Create GDN-specific parameters + # A_log: log decay parameter [HV] - must be float32 + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + + # dt_bias: decay bias [HV] - must be float32 for gdn_decode_klast_bf16_state kernel + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + + # a: input-dependent decay [B, T, HV] + a = ( + torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + * 0.1 + ) + + # b: update gate input [B, T, HV] + if beta: + b_tensor = torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + else: + b_tensor = ( + torch.ones( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + * 10.0 + ) + + # Call gdn_decode_klast_bf16_state kernel + our_state = input_state_kernel.clone() + our_o = gdn_decode_klast_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b_tensor, + initial_state_source=our_state, + use_qk_l2norm_in_kernel=True, + scale=scale, + ) + + torch.cuda.synchronize() + + # Reference implementation with bf16 h state (state_dtype=torch.bfloat16): + # h is stored in bf16, read as fp32 for computation, written back in bf16. 
+ ref_state = input_state_ref_bf16.clone() + ref_outputs = [] + + for t in range(seq_len): + ref_o_t, ref_state = decode_delta_rule( + q[:, t].float(), # [B, H, K] + k[:, t].float(), + v[:, t].float(), + ref_state, # [B, HV, K, V] bf16 + A_log=A_log, + a=a[:, t], # [B, HV] + dt_bias=dt_bias, + b=b_tensor[:, t], # [B, HV] + scale_factor=scale, + softplus_beta=1.0, + softplus_threshold=20.0, + use_l2_norm=True, + state_dtype=torch.bfloat16, # match kernel: h stored in bf16 + ) + ref_outputs.append(ref_o_t) + + # Stack reference outputs: [B, T, HV, V] + ref_o = torch.stack(ref_outputs, dim=1).to(dtype_torch) + + # Tolerances for bf16 h state comparison + atol_o = 0.001 + rtol_o = 0.005 + atol_kv = 0.005 + rtol_kv = 0.005 + + # Compare outputs + torch.testing.assert_close( + our_o.float(), + ref_o.float(), + atol=atol_o, + rtol=rtol_o, + msg=f"Output mismatch for gdn_decode_klast_bf16_state kernel (B={batch_size}, T={seq_len})", + ) + + # Compare states: both in bf16 (kernel [B, HV, V, K], ref [B, HV, K, V]) + ref_state_transposed = ref_state.transpose(-2, -1).contiguous() + torch.testing.assert_close( + our_state.float(), + ref_state_transposed.float(), + atol=atol_kv, + rtol=rtol_kv, + msg=f"State mismatch for gdn_decode_klast_bf16_state kernel (B={batch_size}, T={seq_len})", + ) + + print( + f"āœ“ gdn_decode_klast_bf16_state kernel test passed (batch={batch_size}, T={seq_len}, dtype={dtype}, h_state=bf16)" + ) + + +@pytest.mark.parametrize("beta", [True]) +@pytest.mark.parametrize("alpha", [True]) +@pytest.mark.parametrize("scale", ["auto"]) # Use 1/sqrt(K) like compare_flashinfer.py +@pytest.mark.parametrize("seq_len", [1, 2, 3, 4]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize( + "num_q_heads, num_k_heads, num_v_heads", + [(16, 16, 32)], +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_gdn_decode_klast_bf16_state_kernel( + dtype: str, + num_q_heads: int, + 
num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + seq_len: int, + scale: float | str, + alpha: bool, + beta: bool, + seed: int = int(os.environ.get("SEED", "0")), +): + scale_val = 1.0 / math.sqrt(head_size) if scale == "auto" else scale + _test_gdn_decode_klast_bf16_state_kernel( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + seq_len, + scale_val, + alpha, + beta, + seed, + ) + + +@pytest.mark.parametrize("seq_len", [1, 2, 3, 4]) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize( + "num_q_heads, num_k_heads, num_v_heads", + [(16, 16, 32)], +) +def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + seq_len: int, + seed: int = int(os.environ.get("SEED", "0")), +): + """Verify gated_delta_rule_decode_pretranspose dispatches to gdn_decode_klast_bf16_state when state is bf16 and T<=4, K=V=128. + + Calls the API with bf16 state and checks output/state match the direct gdn_decode_klast_bf16_state call. 
+ """ + _skip_if_not_sm90_or_later() + if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + pytest.skip("gdn_decode_klast_bf16_state kernel not available") + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + dtype = torch.bfloat16 + device = torch.device("cuda") + scale = 1.0 / math.sqrt(head_size) + num_sab_heads = num_v_heads + + q = torch.randn( + batch_size, seq_len, num_q_heads, head_size, dtype=dtype, device=device + ) + k = torch.randn( + batch_size, seq_len, num_k_heads, head_size, dtype=dtype, device=device + ) + v = torch.randn( + batch_size, seq_len, num_v_heads, head_size, dtype=dtype, device=device + ) + a = ( + torch.randn(batch_size, seq_len, num_sab_heads, dtype=dtype, device=device) + * 0.1 + ) + b_tensor = torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype, device=device + ) + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + + # State [B, HV, V, K] in bf16 (Qwen-style K-last) so API uses improved backend + state_api = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, + device=device, + ) + state_direct = state_api.clone() + + # Via API (should dispatch to gdn_decode_klast_bf16_state) + out_api, state_api = gated_delta_rule_decode_pretranspose( + q=q, + k=k, + v=v, + state=state_api, + A_log=A_log, + a=a, + dt_bias=dt_bias, + b=b_tensor, + scale=scale, + use_qk_l2norm=True, + ) + + # Direct improved kernel + out_direct = gdn_decode_klast_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b_tensor, + initial_state_source=state_direct, + use_qk_l2norm_in_kernel=True, + scale=scale, + ) + + torch.testing.assert_close(out_api, out_direct, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(state_api, state_direct, atol=1e-2, rtol=1e-2) + print( + f"āœ“ API gdn_decode_klast_bf16_state 
backend verified (batch={batch_size}, T={seq_len})" + ) + + if __name__ == "__main__": print("Running smoke tests...") print("\n=== Testing PRETRANSPOSE version ===") @@ -648,15 +965,37 @@ def test_verify_kernel_mtp( seed=42, ) + print("\n=== Testing IMPROVED CuTe-DSL version (T=1,2,3,4) ===") + if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + for t in [1, 2, 3, 4]: + _test_gdn_decode_klast_bf16_state_kernel( + dtype="bfloat16", + batch_size=4, + num_q_heads=16, + num_k_heads=16, + num_v_heads=32, + head_size=128, + seq_len=t, + scale=1.0, + alpha=True, + beta=True, + seed=42, + ) + else: + print("⚠ gdn_decode_klast_bf16_state kernel not available, skipping...") + print("\n✅ All smoke tests passed!") print("\nTo run full test suite:") print( - " PRETRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_pretranspose -v" + " PRETRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_pretranspose -v" ) + print( + " NONTRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_nontranspose -v" ) print( - " NONTRANSPOSE: pytest test_decode_delta_rule.py::test_decode_kernel_basic_nontranspose -v" + " MTP (VERIFY): pytest test_decode_delta_rule.py::test_verify_kernel_mtp -v" ) print( - " MTP (VERIFY): pytest test_decode_delta_rule.py::test_verify_kernel_mtp -v" + " gdn_decode_klast_bf16_state: pytest test_decode_delta_rule.py::test_gdn_decode_klast_bf16_state_kernel -v" ) print(" ALL: pytest test_decode_delta_rule.py -v") diff --git a/tests/gdn/test_prefill_delta_rule.py b/tests/gdn/test_prefill_delta_rule.py index f2fd06cbce..8bc87f2ef6 100644 --- a/tests/gdn/test_prefill_delta_rule.py +++ b/tests/gdn/test_prefill_delta_rule.py @@ -117,7 +117,7 @@ def _test_prefill_kernel( scale_factor=scale, alpha=alpha, beta=beta, - kv_dtype=torch.float32, + state_dtype=torch.float32, ) ref_o = ref_o.to(q.dtype) ref_state = ref_state.to(kv_dtype) @@ -364,7 +364,7 @@ def concat_varlen(t1, cu_seq_lens1, t2, cu_seq_lens2): scale_factor=scale, 
alpha=alpha, beta=beta, - kv_dtype=torch.float32, + state_dtype=torch.float32, ) ref_o = ref_o.to(q.dtype) ref_state = ref_state.to(kv_dtype)