diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 41b059f643..5ecafac109 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -18,21 +18,20 @@ GDN Decode Benchmark This benchmark supports: -1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state +1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose + bf16_state 2. Single layout comparison: FlashInfer (CuTe DSL) vs Triton kernel (--compare) 3. MTP benchmark (--version mtp) -4. gdn_decode_klast_bf16_state benchmark (--version gdn_decode_klast_bf16_state) for T=1,2,3,4 +4. BF16 state benchmark (--version bf16_state) for T=1 and MTP T>=1 Kernels benchmarked: - FlashInfer Pretranspose [B, HV, V, K] (V-major layout) - FlashInfer Nontranspose [B, HV, K, V] (K-major layout) - Triton Pretranspose [B, HV, V, K] - Triton Nontranspose [B, HV, K, V] -- gdn_decode_klast_bf16_state [B, HV, V, K] (K-fast layout, T=1..4, bf16 state) - from flashinfer.cute_dsl.gated_delta_rule +- BF16 State [B, HV, V, K] (K-fast layout, bf16 state, T=1 + MTP) Usage: - # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state) + # Default: All layouts comparison python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 # Single layout comparison: FlashInfer vs Triton @@ -44,8 +43,8 @@ # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 - # gdn_decode_klast_bf16_state benchmark (T=1,2,3,4) - python benchmarks/bench_gdn_decode.py --version gdn_decode_klast_bf16_state --batch-size 1 32 128 512 + # BF16 state benchmark (T=1 and MTP) + python benchmarks/bench_gdn_decode.py --version bf16_state --batch-size 1 32 128 512 # Use Qwen3-Next preset (q=k=16, v=32, d=128) python benchmarks/bench_gdn_decode.py --preset qwen3-next --batch-size 1 32 128 512 @@ -62,15 +61,16 @@ ) from flashinfer.testing import bench_gpu_time -# Import the gdn_decode_klast_bf16_state kernel for benchmarking (T=1..4, bf16 state, K-last) +# Import BF16 state kernels for benchmarking try: from flashinfer.gdn_kernels.gdn_decode_bf16_state import ( - gated_delta_rule as gdn_decode_klast_bf16_state, + gated_delta_rule as gdn_decode_bf16_state, + gated_delta_rule_mtp as gdn_decode_bf16_state_mtp, ) - GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True + GDN_DECODE_BF16_STATE_AVAILABLE = True except ImportError: - GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False + GDN_DECODE_BF16_STATE_AVAILABLE = False # ============================================================================ # Utility Functions @@ -1832,8 +1832,8 @@ def verify_correctness_pretranspose( # ============================================================================ -def gdn_decode_klast_bf16_state_wrapper( - q: torch.Tensor, # [B, T, H_Q, K] where T=1,2,3,4 +def gdn_decode_bf16_state_wrapper( + q: torch.Tensor, # [B, T, H_Q, K] k: torch.Tensor, # [B, T, H_K, K] v: torch.Tensor, # [B, T, HV, V] state: torch.Tensor, # [B, HV, V, K] - K-last layout (pretranspose) @@ -1846,33 +1846,56 @@ def gdn_decode_klast_bf16_state_wrapper( use_qk_l2norm: bool = True, softplus_beta: float = 1.0, softplus_threshold: float = 20.0, + intermediate_states_buffer=None, + disable_state_update: bool = False, + initial_state_indices=None, ): """ - Wrapper for gdn_decode_klast_bf16_state GDN kernel. - Supports T=1,2,3,4 (sequence lengths up to 4). + Wrapper for gdn_decode_bf16_state GDN kernel. + Supports T=1 (calls gated_delta_rule) and T>1 (calls gated_delta_rule_mtp). Adapts the interface to match the benchmark's calling convention. Note: The kernel returns output directly, no copy needed. """ - if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: - raise RuntimeError("gdn_decode_klast_bf16_state kernel is not available") - - # Call gdn_decode_klast_bf16_state kernel directly - no wrapper overhead - # Kernel modifies state in-place and returns output tensor - return gdn_decode_klast_bf16_state( - A_log=A_log, - a=a, - dt_bias=dt_bias, - softplus_beta=softplus_beta, - softplus_threshold=softplus_threshold, - q=q, - k=k, - v=v, - b=b, - initial_state_source=state, - use_qk_l2norm_in_kernel=use_qk_l2norm, - scale=scale, - ) + if not GDN_DECODE_BF16_STATE_AVAILABLE: + raise RuntimeError("gdn_decode_bf16_state kernel is not available") + + # Dispatch to T=1 or MTP kernel + T = q.shape[1] + if T == 1: + return gdn_decode_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + initial_state_source=state, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale, + ) + else: + return gdn_decode_bf16_state_mtp( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + initial_state_source=state, + initial_state_indices=initial_state_indices, + intermediate_states_buffer=intermediate_states_buffer, + disable_state_update=disable_state_update, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale, + output=output, + ) def format_time(t): @@ -2030,15 +2053,15 @@ def bench_all_layouts( results["tr_pretrans_us"] = None results["tr_nontrans_us"] = None - # ========== gdn_decode_klast_bf16_state Kernel (K-fast/pretranspose layout) ========== - if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: - # gdn_decode_klast_bf16_state uses [B, HV, V, K] layout (K-fast, same as pretranspose) + # ========== gdn_decode_bf16_state Kernel (K-fast/pretranspose layout) ========== + if GDN_DECODE_BF16_STATE_AVAILABLE: + # gdn_decode_bf16_state uses [B, HV, V, K] layout (K-fast, same as pretranspose) state = torch.randn( batch_size, num_sab_heads, head_size, head_size, - dtype=torch.bfloat16, # gdn_decode_klast_bf16_state uses BF16 state + dtype=torch.bfloat16, # gdn_decode_bf16_state uses BF16 state device="cuda", ) output = torch.empty( @@ -2047,21 +2070,19 @@ def bench_all_layouts( try: times = bench_gpu_time( - lambda: gdn_decode_klast_bf16_state_wrapper( + lambda: gdn_decode_bf16_state_wrapper( q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm ), enable_cupti=True, dry_run_iters=warmup_iters, repeat_iters=bench_iters, ) - results["gdn_decode_klast_bf16_state_us"] = np.median(times) * 1000 + results["gdn_decode_bf16_state_us"] = np.median(times) * 1000 except Exception as e: - results["gdn_decode_klast_bf16_state_us"] = None - print( - f" gdn_decode_klast_bf16_state kernel failed: {type(e).__name__}: {e}" - ) + results["gdn_decode_bf16_state_us"] = None + print(f" gdn_decode_bf16_state kernel failed: {type(e).__name__}: {e}") else: - results["gdn_decode_klast_bf16_state_us"] = None + results["gdn_decode_bf16_state_us"] = None return results @@ -2104,9 +2125,7 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): print() print("\n" + "=" * 160) - print( - "GDN Decode Benchmark (T=1): FlashInfer vs Triton vs gdn_decode_klast_bf16_state" - ) + print("GDN Decode Benchmark (T=1): FlashInfer vs Triton vs gdn_decode_bf16_state") print( f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " f"v_heads={args.num_v_heads}, head_size={args.head_size}, " @@ -2115,8 +2134,8 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): print("=" * 160) print() print( - f"{'batch':>6} | {'FI-PreTr':>8} {'FI-NonTr':>8} | {'TR-PreTr':>8} {'TR-NonTr':>8} | {'KlastBf16':>9} | " - f"{'FI/TR-Pre':>9} {'KlastBf16/FI':>11} {'KlastBf16/TR':>11}" + f"{'batch':>6} | {'FI-PreTr':>8} {'FI-NonTr':>8} | {'TR-PreTr':>8} {'TR-NonTr':>8} | {'Bf16State':>9} | " + f"{'FI/TR-Pre':>9} {'Bf16State/FI':>11} {'Bf16State/TR':>11}" ) print( f"{'':>6} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} | " @@ -2143,21 +2162,21 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): fi_non = result.get("fi_nontrans_us") tr_pre = result.get("tr_pretrans_us") tr_non = result.get("tr_nontrans_us") - klast_bf16_us = result.get("gdn_decode_klast_bf16_state_us") + bf16_state_us = result.get("gdn_decode_bf16_state_us") # FI/TR speedup (>1 means FI faster) fi_tr_pre = format_speedup(fi_pre, tr_pre) - # gdn_decode_klast_bf16_state vs FI-PreTr speedup (>1 means klast_bf16 faster) - klast_bf16_fi_speedup = format_speedup(klast_bf16_us, fi_pre) + # BF16 state vs FI-PreTr speedup (>1 means BF16 state faster) + bf16_fi_speedup = format_speedup(bf16_state_us, fi_pre) - # gdn_decode_klast_bf16_state vs TR-PreTr speedup (>1 means klast_bf16 faster) - klast_bf16_tr_speedup = format_speedup(klast_bf16_us, tr_pre) + # BF16 state vs TR-PreTr speedup (>1 means BF16 state faster) + bf16_tr_speedup = format_speedup(bf16_state_us, tr_pre) print( f"{batch_size:>6} | {format_time(fi_pre)} {format_time(fi_non)} | " - f"{format_time(tr_pre)} {format_time(tr_non)} | {format_time(klast_bf16_us)} | " - f"{fi_tr_pre} {klast_bf16_fi_speedup:>10} {klast_bf16_tr_speedup:>10}" + f"{format_time(tr_pre)} {format_time(tr_non)} | {format_time(bf16_state_us)} | " + f"{fi_tr_pre} {bf16_fi_speedup:>10} {bf16_tr_speedup:>10}" ) print("-" * 160) @@ -2167,25 +2186,23 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): print(" FI-NonTr = FlashInfer Nontranspose [B, HV, K, V]") print(" TR-PreTr = Triton Pretranspose [B, HV, V, K]") print(" TR-NonTr = Triton Nontranspose [B, HV, K, V]") - print( - " KlastBf16 = gdn_decode_klast_bf16_state [B, HV, V, K] (K-fast layout, T=1..4, bf16 state)" - ) + print(" Bf16State = BF16 state kernel [B, HV, V, K] (bf16 state, T=1 + MTP)") print(" FI/TR speedup > 1.0 means FlashInfer is faster than Triton") print( - " KlastBf16/FI speedup > 1.0 means gdn_decode_klast_bf16_state is faster than FlashInfer Pretranspose" + " Bf16State/FI speedup > 1.0 means BF16 state is faster than FlashInfer Pretranspose" ) print( - " KlastBf16/TR speedup > 1.0 means gdn_decode_klast_bf16_state is faster than Triton Pretranspose" + " Bf16State/TR speedup > 1.0 means BF16 state is faster than Triton Pretranspose" ) print() # Summary statistics fi_pre_times = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")] tr_pre_times = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")] - klast_bf16_times = [ - r["gdn_decode_klast_bf16_state_us"] + bf16_state_times = [ + r["gdn_decode_bf16_state_us"] for r in all_results - if r.get("gdn_decode_klast_bf16_state_us") + if r.get("gdn_decode_bf16_state_us") ] if fi_pre_times and tr_pre_times: @@ -2194,47 +2211,47 @@ def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): f"FlashInfer vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) - if klast_bf16_times and fi_pre_times and len(klast_bf16_times) == len(fi_pre_times): + if bf16_state_times and fi_pre_times and len(bf16_state_times) == len(fi_pre_times): speedups = [ - fi / t for t, fi in zip(klast_bf16_times, fi_pre_times, strict=False) + fi / t for t, fi in zip(bf16_state_times, fi_pre_times, strict=False) ] print( - f"gdn_decode_klast_bf16_state vs FlashInfer (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + f"BF16 state vs FlashInfer (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) - if klast_bf16_times and tr_pre_times and len(klast_bf16_times) == len(tr_pre_times): + if bf16_state_times and tr_pre_times and len(bf16_state_times) == len(tr_pre_times): speedups = [ - tr / t for t, tr in zip(klast_bf16_times, tr_pre_times, strict=False) + tr / t for t, tr in zip(bf16_state_times, tr_pre_times, strict=False) ] print( - f"gdn_decode_klast_bf16_state vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + f"BF16 state vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" ) # ============================================================================ -# gdn_decode_klast_bf16_state Multi-Token Benchmark (T=1,2,3,4) +# BF16 State Multi-Token Benchmark # ============================================================================ -def bench_gdn_decode_klast_bf16_state( +def bench_gdn_decode_bf16_state( batch_size: int, - seq_len: int, # T=1,2,3,4 + seq_len: int, num_q_heads: int, num_k_heads: int, num_v_heads: int, head_size: int, dtype: torch.dtype, use_qk_l2norm: bool = True, + cache_intermediate_states: bool = False, + disable_state_update: bool = False, warmup_iters: int = 10, bench_iters: int = 100, ): - """Benchmark gdn_decode_klast_bf16_state kernel for T=1,2,3,4.""" - if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: - raise RuntimeError("gdn_decode_klast_bf16_state kernel is not available") + """Benchmark BF16 state kernel.""" + if not GDN_DECODE_BF16_STATE_AVAILABLE: + raise RuntimeError("gdn_decode_bf16_state kernel is not available") - assert seq_len in [1, 2, 3, 4], ( - f"gdn_decode_klast_bf16_state supports T=1,2,3,4, got T={seq_len}" - ) + assert seq_len >= 1, f"seq_len must be >= 1, got T={seq_len}" num_o_heads = max(num_q_heads, num_v_heads) num_sab_heads = num_o_heads @@ -2261,18 +2278,45 @@ def bench_gdn_decode_klast_bf16_state( device="cuda", ) - # Pre-allocate output + # Intermediate states buffer (MTP only, when caching is enabled) + intermediate_states_buffer = None + if cache_intermediate_states and T > 1: + intermediate_states_buffer = torch.zeros( + batch_size, + T, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, + device="cuda", + ) + + # Pre-allocate output and state indices (avoid per-call torch.arange overhead in CUPTI) output = torch.empty( batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" ) + initial_state_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda") # Scale factor scale = 1.0 / (head_size**0.5) # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing) kernel_times_ms = bench_gpu_time( - lambda: gdn_decode_klast_bf16_state_wrapper( - q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + lambda: gdn_decode_bf16_state_wrapper( + q, + k, + v, + state, + A_log, + a, + dt_bias, + b, + scale, + output, + use_qk_l2norm, + intermediate_states_buffer=intermediate_states_buffer, + disable_state_update=disable_state_update, + initial_state_indices=initial_state_indices, ), enable_cupti=True, dry_run_iters=warmup_iters, @@ -2284,7 +2328,7 @@ def bench_gdn_decode_klast_bf16_state( flops = gdn_decode_flops( batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len ) - # gdn_decode_klast_bf16_state uses BF16 state (2 bytes), not FP32 (4 bytes) + # gdn_decode_bf16_state uses BF16 state (2 bytes), not FP32 (4 bytes) bytes_accessed = gdn_decode_bytes( batch_size, num_q_heads, @@ -2293,8 +2337,8 @@ def bench_gdn_decode_klast_bf16_state( head_size, dtype, seq_len, - disable_state_update=False, - state_dtype_bytes=2, # BF16 state for gdn_decode_klast_bf16_state + disable_state_update=disable_state_update, + state_dtype_bytes=2, # BF16 state for gdn_decode_bf16_state ) kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 @@ -2311,25 +2355,29 @@ def bench_gdn_decode_klast_bf16_state( } -def run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm): - """Run gdn_decode_klast_bf16_state benchmark for T=1,2,3,4.""" - if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: - print("Error: gdn_decode_klast_bf16_state kernel is not available.") - print("Make sure flashinfer.cute_dsl.gated_delta_rule is importable.") +def run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm): + """Run BF16 state benchmark for T=1 and MTP T>=1.""" + if not GDN_DECODE_BF16_STATE_AVAILABLE: + print("Error: BF16 state kernel is not available.") + print("Make sure flashinfer.gdn_kernels.gdn_decode_bf16_state is importable.") return - # Filter seq_len to only valid values (1,2,3,4) - valid_seq_lens = [t for t in args.seq_len if t in [1, 2, 3, 4]] + valid_seq_lens = [t for t in args.seq_len if t >= 1] if not valid_seq_lens: - print("Error: --seq-len must include values from [1, 2, 3, 4]") + print("Error: --seq-len must include values >= 1") return + cache_intermediate = getattr(args, "cache_intermediate_states", False) + disable_state_update = not getattr(args, "update_state", False) + print("\n" + "=" * 100) - print(f"gdn_decode_klast_bf16_state GDN Benchmark (T={valid_seq_lens})") + print(f"BF16 State GDN Benchmark (T={valid_seq_lens})") print( f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " f"v_heads={args.num_v_heads}, head_size={args.head_size}, " - f"dtype={args.dtype}, qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}" + f"dtype={args.dtype}, qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}, " + f"cache_intermediate={'ON' if cache_intermediate else 'OFF'}, " + f"update_state={'ON' if not disable_state_update else 'OFF'}" ) print("=" * 100) print() @@ -2340,7 +2388,7 @@ def run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm): for batch_size in args.batch_size: for seq_len in valid_seq_lens: try: - result = bench_gdn_decode_klast_bf16_state( + result = bench_gdn_decode_bf16_state( batch_size=batch_size, seq_len=seq_len, num_q_heads=args.num_q_heads, @@ -2349,6 +2397,8 @@ def run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm): head_size=args.head_size, dtype=dtype, use_qk_l2norm=use_qk_l2norm, + cache_intermediate_states=cache_intermediate, + disable_state_update=disable_state_update, warmup_iters=args.warmup, bench_iters=args.iters, ) @@ -2684,8 +2734,8 @@ def main(): # MTP comparison: FlashInfer vs Triton python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 - # gdn_decode_klast_bf16_state benchmark (T=1,2,3,4) - python benchmarks/bench_gdn_decode.py --version gdn_decode_klast_bf16_state --batch-size 1 32 128 512 + # BF16 state benchmark (T=1 and MTP) + python benchmarks/bench_gdn_decode.py --version bf16_state --batch-size 1 32 128 512 """, ) parser.add_argument( @@ -2721,18 +2771,18 @@ def main(): "pretranspose", "nontranspose", "mtp", - "gdn_decode_klast_bf16_state", + "bf16_state", "all", ], default="nontranspose", - help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), gdn_decode_klast_bf16_state (T=1..4, bf16 state, K-last), or all", + help="Kernel version: pretranspose, nontranspose, mtp, bf16_state, or all", ) parser.add_argument( "--seq-len", type=int, nargs="+", default=[1, 2, 3, 4], - help="Sequence lengths: for MTP use T>1, for gdn_decode_klast_bf16_state use T=1,2,3,4", + help="Sequence lengths: for MTP use T>1, for bf16_state use any T>=1", ) parser.add_argument( "--cache-intermediate-states", @@ -2792,11 +2842,11 @@ def main(): run_comparison_benchmark(args, dtype, use_qk_l2norm) else: run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm) - elif args.version == "gdn_decode_klast_bf16_state": - # gdn_decode_klast_bf16_state benchmark for T=1,2,3,4 - run_gdn_decode_klast_bf16_state_benchmark(args, dtype, use_qk_l2norm) + elif args.version == "bf16_state": + # BF16 state benchmark: T=1 and MTP T>=2 vs FP32 MTP + run_gdn_decode_bf16_state_benchmark(args, dtype, use_qk_l2norm) else: - # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_klast_bf16_state) + # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose + gdn_decode_bf16_state) run_all_layouts_benchmark(args, dtype, use_qk_l2norm) diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index ba5f3230c5..edd108c58b 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -45,16 +45,18 @@ def flashinfer_api(func): # type: ignore[misc] return func -# GDN decode K-last bf16 state kernel (T=1..4, bf16 state, K-last layout) - optional backend +# GDN decode BF16 state kernels - optional backend try: from .gdn_kernels.gdn_decode_bf16_state import ( - gated_delta_rule as _gated_delta_rule_gdn_decode_klast_bf16_state, + gated_delta_rule as _gated_delta_rule_bf16_state, + gated_delta_rule_mtp as _gated_delta_rule_bf16_state_mtp, ) - _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True + _GDN_DECODE_BF16_STATE_AVAILABLE = True except ImportError: - _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False - _gated_delta_rule_gdn_decode_klast_bf16_state = None + _GDN_DECODE_BF16_STATE_AVAILABLE = False + _gated_delta_rule_bf16_state = None + _gated_delta_rule_bf16_state_mtp = None # Pretranspose decode kernel (V-major state, T=1) try: @@ -132,8 +134,8 @@ def gated_delta_rule_decode_pretranspose( Current value of shape ``[B, 1, HV, V]``. Must be float16/bfloat16. state (Optional[torch.Tensor]): Current state of shape ``[B, HV, V, K]`` (v-major / K-last layout). - Float32: legacy kernel (T=1 only). Bfloat16: gdn_decode_klast_bf16_state backend - when T in 1..4 and K=V=128. Will be updated in-place. + Float32: legacy kernel (T=1 only). Bfloat16: BF16 state backend + (T=1 or MTP for T>1) when K=V=128. Will be updated in-place. Pass ``None`` when using ``initial_state`` / ``initial_state_indices`` instead. A_log (torch.Tensor): Log decay parameter of shape ``[HV]``. Must be float32. @@ -156,7 +158,7 @@ def gated_delta_rule_decode_pretranspose( When provided, the kernel gathers directly from the pool using ``initial_state_indices`` and writes updates back in-place — eliminating the caller-side gather/scatter overhead. - Requires bfloat16 state with T in 1..4 and K=V=128 (bf16 fast path). + Requires bfloat16 state with K=V=128 (bf16 fast path). initial_state_indices (Optional[torch.Tensor]): Per-batch indices of shape ``[B]`` (int32 or int64) mapping each batch entry to its slot in ``initial_state``. Required when ``initial_state`` @@ -172,10 +174,10 @@ def gated_delta_rule_decode_pretranspose( - State is always updated in-place; the pool path writes directly into ``initial_state`` memory (no separate scatter step needed) - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 - and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used - (supports both the direct ``state`` path and the pool+indices path). + and K=V=128, the BF16 state kernel is used (T=1 or MTP for T>1). + The pool+indices path routes through the MTP kernel. - pool+indices (``initial_state``/``initial_state_indices``) supported on - both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path + both the bf16 fast path (K=V=128) and the float32 legacy path (T=1). The float32 path also supports negative indices for padding. - Legacy path (float32 state, T=1): K and V must be multiples of 4. """ @@ -205,36 +207,53 @@ def gated_delta_rule_decode_pretranspose( f"Expected state shape [B={B}, HV={HV}, V={V}, K={K}], got {state.shape}" ) - # Backend: gdn_decode_klast_bf16_state when bf16 state, T<=4, K-last layout, K=V=128 + # Backend: BF16 state kernel when bf16 state, K=V=128 state_dtype = initial_state.dtype if use_pool else state.dtype - use_gdn_decode_klast_bf16_state = ( - _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE + use_bf16_state = ( + _GDN_DECODE_BF16_STATE_AVAILABLE and state_dtype == torch.bfloat16 - and T in (1, 2, 3, 4) and K == 128 and V == 128 ) - if use_gdn_decode_klast_bf16_state: + if use_bf16_state: assert q.dtype in (torch.float16, torch.bfloat16), ( f"q must be float16/bfloat16, got {q.dtype}" ) assert A_log.dtype == torch.float32, f"A_log must be float32, got {A_log.dtype}" scale_val = K**-0.5 if scale is None else scale - out = _gated_delta_rule_gdn_decode_klast_bf16_state( - A_log=A_log, - a=a, - dt_bias=dt_bias, - softplus_beta=1.0, - softplus_threshold=20.0, - q=q, - k=k, - v=v, - b=b, - initial_state_source=initial_state if use_pool else state, - initial_state_indices=initial_state_indices, - use_qk_l2norm_in_kernel=use_qk_l2norm, - scale=scale_val, - ) + if T == 1 and not use_pool: + # T=1 kernel does not accept initial_state_indices + out = _gated_delta_rule_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b, + initial_state_source=state, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale_val, + ) + else: + # MTP kernel supports T>=1 and pool+indices + out = _gated_delta_rule_bf16_state_mtp( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b, + initial_state_source=initial_state if use_pool else state, + initial_state_indices=initial_state_indices, + use_qk_l2norm_in_kernel=use_qk_l2norm, + scale=scale_val, + ) output_provided = output is not None target_dtype = output.dtype if output_provided else q.dtype if output is not None: diff --git a/flashinfer/gdn_kernels/__init__.py b/flashinfer/gdn_kernels/__init__.py index bdd75e732e..15f49a341c 100644 --- a/flashinfer/gdn_kernels/__init__.py +++ b/flashinfer/gdn_kernels/__init__.py @@ -8,27 +8,31 @@ provides high-performance CuTe DSL kernel implementations for specific use cases. Exported Kernels: -- gated_delta_rule: BF16 hidden state decode kernel (T=1,2,3,4) -- GatedDeltaRuleKernel: Kernel class for advanced usage +- gated_delta_rule: BF16 hidden state decode kernel (T=1) +- gated_delta_rule_mtp: BF16 hidden state MTP kernel (T>=1) +- gated_delta_rule_bf16state_cooprow: backward compat alias for gated_delta_rule +- gated_delta_rule_bf16state_cooprow_mtp: backward compat alias for gated_delta_rule_mtp - run_pretranspose_decode: Pretranspose (V-major) decode kernel - run_nontranspose_decode: Nontranspose (K-major) decode kernel - run_mtp_decode: Multi-token processing decode kernel - get_tile_v_mtp, get_vec_size_mtp: MTP hyperparameter helpers """ -from typing import Optional, Type - try: from .gdn_decode_bf16_state import ( gated_delta_rule, - GatedDeltaRuleKernel, + gated_delta_rule_mtp, + gated_delta_rule_bf16state_cooprow, # backward compat alias + gated_delta_rule_bf16state_cooprow_mtp, # backward compat alias ) _has_cute_dsl = True except ImportError: _has_cute_dsl = False gated_delta_rule = None # type: ignore - GatedDeltaRuleKernel: Optional[Type] = None # type: ignore + gated_delta_rule_mtp = None # type: ignore + gated_delta_rule_bf16state_cooprow = None # type: ignore + gated_delta_rule_bf16state_cooprow_mtp = None # type: ignore try: from .gdn_decode_pretranspose import run_pretranspose_decode @@ -49,7 +53,9 @@ __all__ = [ "gated_delta_rule", - "GatedDeltaRuleKernel", + "gated_delta_rule_mtp", + "gated_delta_rule_bf16state_cooprow", + "gated_delta_rule_bf16state_cooprow_mtp", "run_pretranspose_decode", "run_nontranspose_decode", "run_mtp_decode", diff --git a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py index 5388e1b870..fa4a4a4f4f 100644 --- a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py +++ b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py @@ -12,29 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. +""" -Gated Delta Rule Decode Kernel (BF16 Hidden State) - CuTe-DSL Implementation -============================================================================ - -RELOCATED: This file was previously located at flashinfer/cute_dsl/gated_delta_rule.py - and has been moved to flashinfer/gdn_decode/gdn_decode_bf16_state.py - to better reflect its domain-specific purpose (GDN decode with BF16 state). - -High-performance CUDA kernel implementing the Gated Delta Rule linear attention -mechanism for decode-phase inference, supporting sequence lengths T=1, T=2, T=3, T=4. - -Key Features: -- Unified kernel architecture: T=2/3/4 share a single compile-time specialized kernel - using Constexpr dispatch, while T=1 uses a separate kernel with persistent K-in-registers -- L2-normalized Q/K with configurable scale -- Gated exponential decay of hidden state H via softplus -- Delta rule updates: v_delta = beta * (v - pred) -- Bank-conflict-free cross-warp reductions -- Async H memory loading with aggressive pipelining -- BF16 tensors with FP32 compute for numerical stability -- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads -- Uses scalar FP32 FMA operations for SM90+ (Hopper, Blackwell) compatibility -- Can be optimized with packed F32x2 FMA for SM100+ in future releases +""" +Gated Delta Rule Decode Kernel - BF16 Hidden State +=================================================== + +CuTe DSL kernel for GDN decode with BF16 hidden state storage. +Provides both T=1 (single token) and MTP (multi-token prediction) variants. + +Approach: +- Each warp processes ONE V-row at a time (4 warps = 4 V-rows per iteration) +- Each thread holds vec_size=4 K-elements, using warp-level shuffle reduction +- H state is loaded/stored as BF16, converted to FP32 in registers for compute +- cp.async pipeline with TILE_V=8 x TILE_K=128 tiles + +Architecture: +- 128 threads (4 warps x 32 threads) +- TILE_V=8 rows of H loaded per pipeline stage +- TILE_K=128 (full K dimension) +- Each thread: 4 K-elements (lane_id * 4 to lane_id * 4 + 3) +- Warp shuffle reduction across 32 threads for dot products + +Public API: +- gated_delta_rule(): T=1 single-token decode with BF16 state +- gated_delta_rule_mtp(): Multi-token prediction (T>=1) with BF16 state """ import math @@ -44,102 +46,41 @@ import cutlass.cute as cute import cuda.bindings.driver as cuda import torch -from cutlass import utils -from cutlass._mlir.dialects import nvvm +from cutlass.cute.nvgpu import cpasync from cutlass.cute.runtime import from_dlpack # ============================================================================== # CONSTANTS # ============================================================================== -H_SMEM_PADDING = 8 -H_SMEM_STRIDE = 128 + H_SMEM_PADDING - +TILE_V = 8 +TILE_K = 128 +NUM_STAGES = 2 +NUM_THREADS = 128 +NUM_BLOCKS_PER_STATE = 8 # 8 CTAs per (batch, head) for small batch # ============================================================================== -# SHARED HELPER FUNCTIONS +# CONSTANTS FOR ILP-OPTIMIZED KERNEL (large batch sizes) # ============================================================================== - - -@cute.jit -def write_h_chunk_to_smem(h_chunk_f32, h_sh_chunk, lane_idx, k_base): - """Write F32 register H chunk to BF16 SMEM.""" - for i in cutlass.range_constexpr(32): - h_sh_chunk[lane_idx, k_base + i] = h_chunk_f32[i].to(cutlass.BFloat16) - - -@cute.jit -def store_h_smem_to_gmem(h_sh_chunk, h_out, tidx, v_row_offset): - """Store H from SMEM to GMEM using 128-bit stores.""" - copy_bits = 128 - copy_elems = copy_bits // cutlass.BFloat16.width - - thr_layout = cute.make_layout((16, 8), stride=(8, 1)) - val_layout = cute.make_layout((1, copy_elems)) - - from cutlass.cute.nvgpu import CopyUniversalOp - - atom_store = cute.make_copy_atom( - CopyUniversalOp(), cutlass.BFloat16, num_bits_per_copy=copy_bits - ) - tiled_copy = cute.make_tiled_copy_tv(atom_store, thr_layout, val_layout) - thr_copy = tiled_copy.get_slice(tidx) - - for row_iter in cutlass.range_constexpr(2): - for col_iter in cutlass.range_constexpr(2): - s_tile = cute.local_tile(h_sh_chunk, (16, 64), (row_iter, col_iter)) - g_tile = cute.local_tile( - h_out, (16, 64), (row_iter + (v_row_offset // 16), col_iter) - ) - tS = thr_copy.partition_S(s_tile) - tD = thr_copy.partition_D(g_tile) - cute.copy(atom_store, tS, tD) - - -@cute.jit -def load_h_chunk_async(h_sh_chunk, h_global, tidx, row_offset): - """Load H chunk from GMEM to SMEM using async copy.""" - copy_bits = 128 - copy_elems = copy_bits // cutlass.BFloat16.width - - thr_layout = cute.make_layout((16, 8), stride=(8, 1)) - val_layout = cute.make_layout((1, copy_elems)) - - atom_async_copy = cute.make_copy_atom( - cute.nvgpu.cpasync.CopyG2SOp( - cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL - ), - cutlass.BFloat16, - num_bits_per_copy=copy_bits, - ) - tiled_copy = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) - thr_copy = tiled_copy.get_slice(tidx) - - for row_iter in cutlass.range_constexpr(2): - for col_iter in cutlass.range_constexpr(2): - g_tile = cute.local_tile( - h_global, (16, 64), (row_iter + (row_offset // 16), col_iter) - ) - s_tile = cute.local_tile(h_sh_chunk, (16, 64), (row_iter, col_iter)) - tS = thr_copy.partition_S(g_tile) - tD = thr_copy.partition_D(s_tile) - cute.copy(atom_async_copy, tS, tD) +TILE_V_ILP = 128 # V-tile size: each block processes all 128 V-rows +TILE_K_ILP = 128 # Full K dimension +NUM_THREADS_ILP = 128 # 4 warps +VEC_SIZE_ILP = 4 # Elements per thread along K (changed dynamically) +ILP_ROWS = ( + 8 # Process 8 V-rows simultaneously per group (optimal ILP for latency hiding) +) # ============================================================================== # FMA WRAPPER FUNCTIONS (SM90 Compatibility) # ============================================================================== -# Note: cute.arch.fma_packed_f32x2() generates F32x2 intrinsics that are NOT -# supported on SM90 (Hopper). These wrappers use scalar FMA operations that -# work on all SM90+ architectures. Future optimization: add architecture- -# specific variants for SM100+ (Blackwell) using packed intrinsics. +# cute.arch.fma_packed_f32x2() generates F32x2 intrinsics NOT supported on SM90. +# These wrappers use scalar FMA operations that work on all SM90+ architectures. +# On SM100+ (Blackwell), use_packed_fma=True selects the native packed path. @cute.jit def fma_pair_mul(a1, a2, b1, b2): - """Multiply two pairs: (a1, a2) * (b1, b2). - - Equivalent to fma_packed_f32x2 with c=(0,0), but compatible with SM90+. - """ + """Multiply two pairs: (a1*b1, a2*b2). SM90-compatible.""" result1 = a1 * b1 result2 = a2 * b2 return result1, result2 @@ -147,1772 +88,2374 @@ def fma_pair_mul(a1, a2, b1, b2): @cute.jit def fma_pair(a1, a2, b1, b2, c1, c2): - """FMA two pairs: (a1, a2) * (b1, b2) + (c1, c2). - - Equivalent to fma_packed_f32x2, but compatible with SM90+. - """ + """FMA two pairs: (a1*b1+c1, a2*b2+c2). SM90-compatible.""" result1 = a1 * b1 + c1 result2 = a2 * b2 + c2 return result1, result2 -@cute.jit -def compute_single_gate( - alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold -): - """Compute gate values (g_exp, beta) for a single token.""" - x = alpha + dt_bias_val - beta_x = softplus_beta * x - softplus_x = cutlass.Float32(0.0) - if beta_x <= softplus_threshold: - softplus_x = (cutlass.Float32(1.0) / softplus_beta) * cute.log( - cutlass.Float32(1.0) + cute.exp(beta_x, fastmath=True), fastmath=True - ) - else: - softplus_x = x - g = -cute.exp(A_log_val, fastmath=True) * softplus_x - g_exp = cute.exp(g, fastmath=True) - beta = cutlass.Float32(1.0) / ( - cutlass.Float32(1.0) + cute.exp(-beta_raw, fastmath=True) - ) - return g_exp, beta - - -@cute.jit -def normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps): - """L2-normalize Q and K vectors, then store to shared memory.""" - q_reg = cute.make_rmem_tensor((4,), cutlass.Float32) - k_reg = cute.make_rmem_tensor((4,), cutlass.Float32) - - for i in cutlass.range_constexpr(4): - q_reg[i] = q_head[lane_idx + i * 32].to(cutlass.Float32) - k_reg[i] = k_head[lane_idx + i * 32].to(cutlass.Float32) - - q_sum_sq = cutlass.Float32(0.0) - k_sum_sq = cutlass.Float32(0.0) - q_sum_sq2 = cutlass.Float32(0.0) - k_sum_sq2 = cutlass.Float32(0.0) - - for i in cutlass.range_constexpr(0, 4, 2): - q_sum_sq, q_sum_sq2 = fma_pair( - q_reg[i], q_reg[i + 1], q_reg[i], q_reg[i + 1], q_sum_sq, q_sum_sq2 - ) - k_sum_sq, k_sum_sq2 = fma_pair( - k_reg[i], k_reg[i + 1], k_reg[i], k_reg[i + 1], k_sum_sq, k_sum_sq2 - ) - - q_sum_sq = q_sum_sq + q_sum_sq2 - k_sum_sq = k_sum_sq + k_sum_sq2 - - for i in cutlass.range_constexpr(5): - q_sum_sq = q_sum_sq + cute.arch.shuffle_sync_bfly( - q_sum_sq, offset=1 << i, mask=0xFFFFFFFF - ) - k_sum_sq = k_sum_sq + cute.arch.shuffle_sync_bfly( - k_sum_sq, offset=1 << i, mask=0xFFFFFFFF - ) - - q_norm = cute.rsqrt(q_sum_sq + eps, fastmath=True) - k_norm = cute.rsqrt(k_sum_sq + eps, fastmath=True) - q_scale_factor = q_norm * scale - - for i in cutlass.range_constexpr(4): - q_sh[lane_idx + i * 32] = q_reg[i] * q_scale_factor - k_sh[lane_idx + i * 32] = k_reg[i] * k_norm - - -@cute.jit -def load_v_to_smem(v_head, v_sh, tidx): - """Load V values from GMEM to SMEM.""" - v_sh[tidx] = v_head[tidx].to(cutlass.Float32) - - -@cute.jit -def load_kq_chunk_from_smem(kq_sh, kq_chunk, k_base): - """Load K or Q chunk from SMEM to registers.""" - for i in cutlass.range_constexpr(32): - kq_chunk[i] = kq_sh[k_base + i] +# ============================================================================== +# KERNEL: T=1 with gdn_decode approach but BF16 state +# ============================================================================== -@cute.jit -def decay_h_from_smem_and_compute_pred( - h_sh_chunk, h_chunk, kq_chunk, g_exp, lane_idx, k_base +@cute.kernel +def gdn_decode_bf16state_cooprow_kernel( + tiled_copy_load: cute.TiledCopy, + h0_source: cute.Tensor, # [B*HV, V, K] as BF16 + smem_layout_staged: cute.Layout, + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + A_log: cute.Tensor, # [HV] + a: cute.Tensor, # [B, 1, HV] + dt_bias: cute.Tensor, # [HV] + q: cute.Tensor, # [B, 1, H, K] + k: cute.Tensor, # [B, 1, H, K] + v: cute.Tensor, # [B, 1, HV, V] + b: cute.Tensor, # [B, 1, HV] + o: cute.Tensor, # [B, 1, HV, V] - output + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], ): - """Load H from SMEM, apply decay, and compute pred = sum_k(h * k).""" - pred = cutlass.Float32(0.0) - pred2 = cutlass.Float32(0.0) - - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair_mul( - h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32), - h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32), - g_exp, - g_exp, - ) - - for i in cutlass.range_constexpr(0, 32, 2): - pred, pred2 = fma_pair( - h_chunk[i], h_chunk[i + 1], kq_chunk[i], kq_chunk[i + 1], pred, pred2 - ) - - pred = pred + pred2 - return pred - - -@cute.jit -def update_h_with_delta(h_chunk, kq_chunk, v_delta): - """Update H with delta: h = h + k * v_delta.""" - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair( - kq_chunk[i], kq_chunk[i + 1], v_delta, v_delta, h_chunk[i], h_chunk[i + 1] - ) - - -@cute.jit -def compute_output(h_chunk, kq_chunk): - """Compute output = sum_k(h * q).""" - out = cutlass.Float32(0.0) - out2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - out, out2 = fma_pair( - h_chunk[i], h_chunk[i + 1], kq_chunk[i], kq_chunk[i + 1], out, out2 - ) - out = out + out2 - return out - - -@cute.jit -def decay_h_in_place(h_chunk, g_exp): - """Apply decay to H in place: h = h * g_exp.""" - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair_mul( - h_chunk[i], h_chunk[i + 1], g_exp, g_exp - ) - - -@cute.jit -def cross_warp_reduce_single(reduce_sh, slot, warp_idx, lane_idx, value): """ - Cross-warp reduction for a single value using bank-conflict-free layout. - Layout: [slot, lane_idx, warp_idx] + T=1 GDN decode kernel using the 'different approach': + - Pipeline loads TILE_V x TILE_K BF16 tiles of H from GMEM to SMEM + - Each warp processes 1 V-row (4 warps = 4 rows per TILE_V=8 iteration) + - Each thread: vec_size=4 K-elements with warp shuffle reduction + - H stored as BF16, compute in FP32 """ - reduce_sh[slot, lane_idx, warp_idx] = value - cute.arch.sync_threads() - reduced_value = ( - reduce_sh[slot, lane_idx, 0] - + reduce_sh[slot, lane_idx, 1] - + reduce_sh[slot, lane_idx, 2] - + reduce_sh[slot, lane_idx, 3] - ) - return reduced_value + tidx, _, _ = cute.arch.thread_idx() + lane_id = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + block_idx, _, _ = cute.arch.block_idx() + batch_idx = block_idx // NUM_BLOCKS_PER_STATE + batch_inner = block_idx % NUM_BLOCKS_PER_STATE + num_v_tiles_per_block = num_v_tiles // NUM_BLOCKS_PER_STATE + i_n = batch_idx // HV + i_hv = batch_idx % HV + i_h = i_hv // (HV // H) + i_t = 0 -@cute.jit -def cross_warp_reduce_two(reduce_sh, slot1, slot2, warp_idx, lane_idx, value1, value2): - """ - Cross-warp reduction for two values simultaneously using bank-conflict-free layout. - Layout: [slot, lane_idx, warp_idx] - """ - reduce_sh[slot1, lane_idx, warp_idx] = value1 - reduce_sh[slot2, lane_idx, warp_idx] = value2 - cute.arch.sync_threads() - reduced1 = ( - reduce_sh[slot1, lane_idx, 0] - + reduce_sh[slot1, lane_idx, 1] - + reduce_sh[slot1, lane_idx, 2] - + reduce_sh[slot1, lane_idx, 3] - ) - reduced2 = ( - reduce_sh[slot2, lane_idx, 0] - + reduce_sh[slot2, lane_idx, 1] - + reduce_sh[slot2, lane_idx, 2] - + reduce_sh[slot2, lane_idx, 3] - ) - return reduced1, reduced2 + smem = cutlass.utils.SmemAllocator() + # Allocate shared memory for H tile pipeline (BF16) + sData = smem.allocate_tensor(cutlass.BFloat16, smem_layout_staged, 128) -@cute.jit -def process_first_token( - h_sh_chunk_curr, - h_chunk, - kq_chunk, - k_sh, - q_sh, - v_sh, - reduce_sh, - o_head, - g_exp, - beta, - v_offset, - pred_slot, - warp_idx, - lane_idx, - k_base, -): - """ - Process the first token in a V-chunk (T=0). - - Load K from SMEM - - Decay H from SMEM and compute pred - - Cross-warp reduce pred (uses pred_slot) - - Update H with delta - - Load Q and compute output - Returns: out (partial output, not yet reduced) - """ - # Load K for this token - load_kq_chunk_from_smem(k_sh, kq_chunk, k_base) + # Allocate shared memory for output (V elements, BF16) + sOutput = smem.allocate_tensor(cutlass.BFloat16, cute.make_layout((V,)), 16) - # Decay H from SMEM and compute pred = H * K - pred = decay_h_from_smem_and_compute_pred( - h_sh_chunk_curr, h_chunk, kq_chunk, g_exp, lane_idx, k_base - ) + # Allocate shared memory for v values (V elements, FP32) + sV = smem.allocate_tensor(cutlass.Float32, cute.make_layout((V,)), 16) - # Reduce pred across warps (slot 0 for first token) - pred_final = cross_warp_reduce_single( - reduce_sh, pred_slot, warp_idx, lane_idx, pred + # Register tensors for K, Q, and H (vec_size=4 per thread) + r_k = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) - - # Compute delta and update H - v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta - update_h_with_delta(h_chunk, kq_chunk, v_delta) - - # Load Q and compute output - load_kq_chunk_from_smem(q_sh, kq_chunk, k_base) - out = compute_output(h_chunk, kq_chunk) - - return out - - -@cute.jit -def process_middle_token( - h_chunk, - kq_chunk, - k_sh, - q_sh, - v_sh, - reduce_sh, - o_head_prev, - g_exp, - beta, - v_offset, - out_slot_prev, - pred_slot, - out_prev, - warp_idx, - lane_idx, - k_base, -): - """ - Process a middle token (T=1, T=2 for T=4 kernel). - - Decay H in place - - Load K, compute pred - - Joint reduction of (prev_out, this_pred) - - Store prev output - - Update H with delta - - Load Q and compute output - Returns: out (partial output, not yet reduced) - """ - # Decay H in place - decay_h_in_place(h_chunk, g_exp) - - # Load K and compute pred - load_kq_chunk_from_smem(k_sh, kq_chunk, k_base) - pred = compute_output(h_chunk, kq_chunk) - - # Joint reduction: reduce out_prev and pred together - out_prev_final, pred_final = cross_warp_reduce_two( - reduce_sh, out_slot_prev, pred_slot, warp_idx, lane_idx, out_prev, pred + r_q = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) - - # Store previous token's output - if warp_idx == 0: - o_head_prev[v_offset + lane_idx] = out_prev_final.to(cutlass.BFloat16) - - # Compute delta and update H - v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta - update_h_with_delta(h_chunk, kq_chunk, v_delta) - - # Load Q and compute output - load_kq_chunk_from_smem(q_sh, kq_chunk, k_base) - out = compute_output(h_chunk, kq_chunk) - - return out - - -@cute.jit -def process_last_token_and_finish( - h_sh_chunk_curr, - h_chunk, - kq_chunk, - k_sh, - q_sh, - v_sh, - reduce_sh, - o_head_prev, - o_head_last, - g_exp, - beta, - v_offset, - out_slot_prev, - pred_slot, - out_slot_last, - out_prev, - warp_idx, - lane_idx, - k_base, -): - """ - Process the last token and finalize the V-chunk. - - Decay H in place - - Load K, compute pred - - Joint reduction of (prev_out, this_pred) - - Store prev output - - Update H with delta - - Compute last output and reduce - - Write H back to SMEM - - Store last output - """ - # Decay H in place - decay_h_in_place(h_chunk, g_exp) - - # Load K and compute pred - load_kq_chunk_from_smem(k_sh, kq_chunk, k_base) - pred = compute_output(h_chunk, kq_chunk) - - # Joint reduction: reduce out_prev and pred together - out_prev_final, pred_final = cross_warp_reduce_two( - reduce_sh, out_slot_prev, pred_slot, warp_idx, lane_idx, out_prev, pred + r_h = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) - - # Store previous token's output - if warp_idx == 0: - o_head_prev[v_offset + lane_idx] = out_prev_final.to(cutlass.BFloat16) - - # Compute delta and update H - v_delta = (v_sh[v_offset + lane_idx] - pred_final) * beta - update_h_with_delta(h_chunk, kq_chunk, v_delta) - - # Compute last output - load_kq_chunk_from_smem(q_sh, kq_chunk, k_base) - out_last = compute_output(h_chunk, kq_chunk) - - # Final reduction and store - out_last_final = cross_warp_reduce_single( - reduce_sh, out_slot_last, warp_idx, lane_idx, out_last + # BF16 register tensors for vectorized loading + r_q_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_k_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_v_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - write_h_chunk_to_smem(h_chunk, h_sh_chunk_curr, lane_idx, k_base) - if warp_idx == 0: - o_head_last[v_offset + lane_idx] = out_last_final.to(cutlass.BFloat16) - - -# ============================================================================== -# UNIFIED V-CHUNK PROCESSING FOR SEQLEN=2/3/4 -# ============================================================================== + # Each thread's K-range: lane_id*4 .. lane_id*4+3 + k_start = lane_id * vec_size + + # Read gate values from GMEM early (hide latency during subsequent syncs) + r_A_log = cutlass.Float32(A_log[i_hv]) + r_a = cutlass.Float32(a[i_n, i_t, i_hv]) + r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + r_b = cutlass.Float32(b[i_n, i_t, i_hv]) + + cute.arch.barrier() + + # Global memory views + gSrc_batch = h0_source[(batch_idx, None, None)] # (V, K) in BF16 + gDst = cute.local_tile(h0_source, (1, TILE_V, TILE_K), (batch_idx, None, 0)) + + # V-direction tiles + gSrc = cute.local_tile( + gSrc_batch, (TILE_V, TILE_K), (None, 0) + ) # (TILE_V, TILE_K, num_v_tiles) + + # Partition for async load + thr_copy_load = tiled_copy_load.get_slice(tidx) + + # =================================================================== + # Prefetch first pipeline stages + # =================================================================== + start_v_tiles = batch_inner * num_v_tiles_per_block + prefetch_count = cutlass.min(NUM_STAGES - 1, num_v_tiles_per_block) + for v_tiles in range(start_v_tiles, start_v_tiles + prefetch_count): + stage = (v_tiles - start_v_tiles) % NUM_STAGES + + gSrc_tile = gSrc[(None, None, v_tiles)] + sData_stage = sData[(None, None, stage)] + + thr_gSrc = thr_copy_load.partition_S(gSrc_tile) + thr_sData = thr_copy_load.partition_D(sData_stage) + + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + # Load q, k as BF16, convert to FP32 + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + # Load v as BF16, convert to FP32, store to sV + v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) + cute.autovec_copy(v_tile, r_v_bf16) + for i in cutlass.range_constexpr(vec_size): + sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) + + cute.arch.barrier() + + # =================================================================== + # Compute gate values: g_exp and beta + # =================================================================== + r_g = 0.0 + r_beta = 0.0 + if lane_id == 0: + x = r_a + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + + if beta_x <= softplus_threshold: + exp_beta_x = cute.exp(beta_x, fastmath=True) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x + + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) + r_g = cute.exp(r_g_value, fastmath=True) + + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) + + # =================================================================== + # L2 normalization of Q and K (if enabled) + # =================================================================== + if use_qk_l2norm: + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) -@cute.jit -def process_vchunk_unified_234( - h_sh_chunk_curr, - h_sh_chunk_prev, - h_out, - h_chunk, - kq_chunk, - k_sh0, - k_sh1, - k_sh2, - k_sh3, - q_sh0, - q_sh1, - q_sh2, - q_sh3, - v_sh0, - v_sh1, - v_sh2, - v_sh3, - reduce_sh, - o_head0, - o_head1, - o_head2, - o_head3, - g_exp0, - g_exp1, - g_exp2, - g_exp3, - beta0, - beta1, - beta2, - beta3, - v_offset, - prev_v_offset, - store_prev, - tidx, - warp_idx, - lane_idx, - k_base, - NUM_TOKENS: cutlass.Constexpr[int], -): - """ - Unified V-chunk processing for 2, 3, or 4 tokens using Constexpr parameter. + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k + + # Apply scale to Q + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale + + # =================================================================== + # Main loop: process V tiles + # =================================================================== + end_v_tiles = start_v_tiles + num_v_tiles_per_block + for v_tiles in range(start_v_tiles, end_v_tiles): + stage = (v_tiles - start_v_tiles) % NUM_STAGES + + # Wait for current stage + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # Prefetch next tile + next_v_tiles = v_tiles + prefetch_count + if next_v_tiles < end_v_tiles: + next_stage = (next_v_tiles - start_v_tiles) % NUM_STAGES + + gSrc_next = gSrc[(None, None, next_v_tiles)] + sData_next = sData[(None, None, next_stage)] + + thr_gSrc = thr_copy_load.partition_S(gSrc_next) + thr_sData = thr_copy_load.partition_D(sData_next) + + cute.copy(tiled_copy_load, thr_gSrc, thr_sData) + cute.arch.cp_async_commit_group() + + # Process TILE_V rows, 4 rows at a time (one per warp) + for row in cutlass.range_constexpr(0, TILE_V, 4): + row_offset = tidx // 32 # = warp_idx + sum_hk = 0.0 + + # Load H from BF16 SMEM, convert to FP32 in registers + sData_tile = cute.local_tile( + sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) + ) + # Manual load + convert BF16 -> FP32 + for i in cutlass.range_constexpr(vec_size): + r_h[i] = cutlass.Float32(sData_tile[i]) + + # Decay H and compute dot product: sum_hk = sum(h * k) + for i in cutlass.range_constexpr(vec_size): + r_h[i] = r_h[i] * r_g + sum_hk += r_h[i] * r_k[i] + + # Warp-level reduction for sum_hk + for offset in [16, 8, 4, 2, 1]: + sum_hk += cute.arch.shuffle_sync_bfly( + sum_hk, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Delta update: v_delta = beta * (v - pred) + v_new = sV[v_tiles * TILE_V + row + row_offset] - sum_hk + v_new = v_new * r_beta + + # Update H and compute output dot product: sum_hq = sum(h * q) + sum_hq = 0.0 + for i in cutlass.range_constexpr(vec_size): + r_h[i] += r_k[i] * v_new + sum_hq += r_h[i] * r_q[i] + + # Write updated H back to GMEM as BF16 via gDst + gDst_tile = cute.local_tile( + gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) + ) + for i in cutlass.range_constexpr(vec_size): + gDst_tile[i] = cutlass.BFloat16(r_h[i]) - This function handles V-chunk processing for all multi-token cases (T=2, T=3, T=4) - using compile-time specialization via NUM_TOKENS. + # Warp-level reduction for sum_hq + for offset in [16, 8, 4, 2, 1]: + sum_hq += cute.arch.shuffle_sync_bfly( + sum_hq, offset=offset, mask=-1, mask_and_clamp=31 + ) - Pattern: - - Token 0: First token (always) - - Tokens 1 to NUM_TOKENS-2: Middle tokens (compile-time unrolled) - - Token NUM_TOKENS-1: Last token (always) - """ - # Store previous H chunk if needed - if store_prev: - store_h_smem_to_gmem(h_sh_chunk_prev, h_out, tidx, prev_v_offset) - - # Token 0: First token processing (always executed) - out0 = process_first_token( - h_sh_chunk_curr, - h_chunk, - kq_chunk, - k_sh0, - q_sh0, - v_sh0, - reduce_sh, - o_head0, - g_exp0, - beta0, - v_offset, - 0, # pred_slot=0 - warp_idx, - lane_idx, - k_base, - ) + o_idx = v_tiles * TILE_V + row + row_offset + if lane_id == 0 and o_idx < V: + sOutput[o_idx] = cutlass.BFloat16(sum_hq) - # Compile-time dispatch based on NUM_TOKENS - if NUM_TOKENS == 2: - # For T=2: Token 1 is the last token - process_last_token_and_finish( - h_sh_chunk_curr, - h_chunk, - kq_chunk, - k_sh1, - q_sh1, - v_sh1, - reduce_sh, - o_head0, - o_head1, - g_exp1, - beta1, - v_offset, - 1, - 2, - 3, # out_slot_prev=1, pred_slot=2, out_slot_last=3 - out0, - warp_idx, - lane_idx, - k_base, - ) - elif NUM_TOKENS == 3: - # For T=3: Token 1 is middle, Token 2 is last - out1 = process_middle_token( - h_chunk, - kq_chunk, - k_sh1, - q_sh1, - v_sh1, - reduce_sh, - o_head0, - g_exp1, - beta1, - v_offset, - 1, - 2, # out_slot_prev=1, pred_slot=2 - out0, - warp_idx, - lane_idx, - k_base, - ) - process_last_token_and_finish( - h_sh_chunk_curr, - h_chunk, - kq_chunk, - k_sh2, - q_sh2, - v_sh2, - reduce_sh, - o_head1, - o_head2, - g_exp2, - beta2, - v_offset, - 3, - 4, - 5, # out_slot_prev=3, pred_slot=4, out_slot_last=5 - out1, - warp_idx, - lane_idx, - k_base, - ) - else: - # For T=4: Tokens 1,2 are middle, Token 3 is last - out1 = process_middle_token( - h_chunk, - kq_chunk, - k_sh1, - q_sh1, - v_sh1, - reduce_sh, - o_head0, - g_exp1, - beta1, - v_offset, - 1, - 2, # out_slot_prev=1, pred_slot=2 - out0, - warp_idx, - lane_idx, - k_base, - ) - out2 = process_middle_token( - h_chunk, - kq_chunk, - k_sh2, - q_sh2, - v_sh2, - reduce_sh, - o_head1, - g_exp2, - beta2, - v_offset, - 3, - 4, # out_slot_prev=3, pred_slot=4 - out1, - warp_idx, - lane_idx, - k_base, - ) - # Last token for NUM_TOKENS=4: Token 3 - process_last_token_and_finish( - h_sh_chunk_curr, - h_chunk, - kq_chunk, - k_sh3, - q_sh3, - v_sh3, - reduce_sh, - o_head2, - o_head3, - g_exp3, - beta3, - v_offset, - 5, - 6, - 7, # out_slot_prev=5, pred_slot=6, out_slot_last=7 - out2, - warp_idx, - lane_idx, - k_base, - ) + # =================================================================== + # Final writeback: output from SMEM to GMEM + # =================================================================== + cute.arch.barrier() + if tidx >= start_v_tiles * TILE_V and tidx < end_v_tiles * TILE_V: + o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx] # ============================================================================== -# SEQLEN=1 KERNEL (Persistent K Optimization) +# KERNEL: ILP-OPTIMIZED T=1 with direct GMEM->register loads, 8-row ILP # ============================================================================== +# Architecture (matches MTP kernel pattern): +# - Grid: (B * HV * num_v_tiles, 1, 1) - each block handles one TILE_V chunk +# - 128 threads = 4 groups of 32 threads (full warps) +# - Each group processes TILE_V/4 V-rows total, 8 rows at a time (ILP=8) +# - H loaded directly from GMEM into registers via autovec_copy (128-bit BF16 loads) +# - No SMEM pipeline - ILP hides memory latency instead @cute.kernel -def gated_delta_rule_decode_kernel_seqlen1( - gQ: cute.Tensor, - gK: cute.Tensor, - gV: cute.Tensor, - ga: cute.Tensor, - gb: cute.Tensor, - gA_log: cute.Tensor, - gdt_bias: cute.Tensor, - gH: cute.Tensor, - gH_slot_indices: cute.Tensor, - gO: cute.Tensor, - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, +def gdn_decode_bf16state_ilp_kernel( + h0_source: cute.Tensor, # [B*HV, V, K] as BF16 (K-last, autovec_copy compatible) + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + A_log: cute.Tensor, # [HV] + a: cute.Tensor, # [B, 1, HV] + dt_bias: cute.Tensor, # [HV] + q: cute.Tensor, # [B, 1, H, K] + k: cute.Tensor, # [B, 1, H, K] + v: cute.Tensor, # [B, 1, HV, V] + b: cute.Tensor, # [B, 1, HV] + o: cute.Tensor, # [B, 1, HV, V] - output + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + use_packed_fma: cutlass.Constexpr[bool], ): """ - Seqlen=1 kernel with persistent K optimization. - OPTIMIZATIONS: - 1. PERSISTENT K IN REGISTERS ONLY: K[k_base:k_base+32] kept for entire kernel - Q is reloaded per chunk (lower register pressure than V3) - 2. AGGRESSIVE PIPELINING: Load chunks 2 ahead, store during next compute - 3. [4,32] CROSS-WARP REDUCTION: Correct lane-preserving reduction + ILP-optimized T=1 GDN decode kernel with BF16 state. + Direct GMEM->register loads with 8-row ILP for high memory throughput. """ tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - HV = cutlass.Int32(gV.shape[2]) - H = cutlass.Int32(gQ.shape[2]) - - batch_idx = bidx // HV - value_head_idx = bidx % HV - query_head_idx = value_head_idx // (HV // H) - pool_batch_idx = gH_slot_indices[batch_idx] - if pool_batch_idx < 0: - pool_batch_idx = cutlass.Int32(0) - - smem = utils.SmemAllocator() - - # Compute gates using shared helper - alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) - beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) - A_log_val = gA_log[value_head_idx] - dt_bias_val = gdt_bias[value_head_idx] - g_exp, beta = compute_single_gate( - alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + lane_id = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # 4 groups (= 4 warps), each full warp of 32 threads + threads_per_group: cutlass.Constexpr[int] = 32 # noqa: F841 + num_groups: cutlass.Constexpr[int] = 4 + group_idx = warp_idx + lane_in_group = lane_id + + batch_idx, _, _ = cute.arch.block_idx() + + # Decode block index: (i_n, i_hv, i_v) from batch_idx + i_v = batch_idx % num_v_tiles + tmp = batch_idx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) + i_t = 0 + + # Load A_log and dt_bias once + r_A_log = cutlass.Float32(A_log[i_hv]) + r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + + # No shared memory needed for ILP kernel (direct GMEM access) + + # Register arrays for q, k, and h (8 rows of vec_size=4 each) + r_q = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) - - # Allocate SMEM - h_sh_chunk0 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + r_k = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) - h_sh_chunk1 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + r_h = cute.make_rmem_tensor( + cute.make_layout((ILP_ROWS, vec_size), stride=(vec_size, 1)), cutlass.Float32 ) - h_sh_chunk2 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + + # BF16 register tensors for vectorized loading from BF16 state + # We use 4 separate BF16 register tensors for ILP loads + r_hb0 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - h_sh_chunk3 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + r_hb1 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - q_sh = smem.allocate_tensor(cutlass.Float32, 128) - k_sh = smem.allocate_tensor(cutlass.Float32, 128) - - # pred_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) - # out_sh = smem.allocate_tensor(cutlass.Float32, cute.make_layout((4, 32))) - pred_sh = smem.allocate_tensor( - cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + r_hb2 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - out_sh = smem.allocate_tensor( - cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + r_hb3 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - h_global = gH[(pool_batch_idx, value_head_idx, None, None)] - - # Launch first 2 async loads - load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0) - nvvm.cp_async_commit_group() - load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32) - nvvm.cp_async_commit_group() - - # L2 normalization - q_head = gQ[(batch_idx, 0, query_head_idx, None)] - k_head = gK[(batch_idx, 0, query_head_idx, None)] - - warp_idx = tidx // 32 - lane_idx = tidx % 32 - - # Use shared helper for Q/K normalization (only warp 0 does the work) - if warp_idx == 0: - normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps) - - cute.arch.sync_threads() - - # Load V - v_head = gV[(batch_idx, 0, value_head_idx, None)] - v_sh = smem.allocate_tensor(cutlass.Float32, 128) - v_sh[tidx] = v_head[tidx].to(cutlass.Float32) - - # Registers: h_chunk + k_chunk (persistent) + qk_temp (reused for Q) - h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) - k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) # PERSISTENT K! - qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32) - - k_base = warp_idx * 32 - - # Load K ONCE - keep for entire kernel - for i in cutlass.range_constexpr(32): - k_chunk[i] = k_sh[k_base + i] - - h_out = gH[(pool_batch_idx, value_head_idx, None, None)] - o_head = gO[(batch_idx, 0, value_head_idx, None)] - - # ======================================================================== - # CHUNK 0 - # ======================================================================== - nvvm.cp_async_wait_group(1) - cute.arch.sync_threads() - - pred = cutlass.Float32(0.0) - pred2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair_mul( - h_sh_chunk0[lane_idx, k_base + i].to(cutlass.Float32), - h_sh_chunk0[lane_idx, k_base + i + 1].to(cutlass.Float32), - g_exp, - g_exp, - ) - for i in cutlass.range_constexpr(0, 32, 2): - pred, pred2 = fma_pair( - h_chunk[i], h_chunk[i + 1], k_chunk[i], k_chunk[i + 1], pred, pred2 - ) - pred = pred + pred2 - - pred_sh[lane_idx, warp_idx] = pred - cute.arch.sync_threads() - pred_final = ( - pred_sh[lane_idx, 0] - + pred_sh[lane_idx, 1] - + pred_sh[lane_idx, 2] - + pred_sh[lane_idx, 3] + r_hb4 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - v_val = (v_sh[lane_idx] - pred_final) * beta - - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair( - k_chunk[i], k_chunk[i + 1], v_val, v_val, h_chunk[i], h_chunk[i + 1] - ) - - # Load Q for output computation - for i in cutlass.range_constexpr(32): - qk_temp[i] = q_sh[k_base + i] - - out = cutlass.Float32(0.0) - out2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - out, out2 = fma_pair( - h_chunk[i], h_chunk[i + 1], qk_temp[i], qk_temp[i + 1], out, out2 - ) - out = out + out2 - - out_sh[lane_idx, warp_idx] = out - cute.arch.sync_threads() - out_final = ( - out_sh[lane_idx, 0] - + out_sh[lane_idx, 1] - + out_sh[lane_idx, 2] - + out_sh[lane_idx, 3] + r_hb5 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - write_h_chunk_to_smem(h_chunk, h_sh_chunk0, lane_idx, k_base) - if warp_idx == 0: - o_head[lane_idx] = out_final.to(cutlass.BFloat16) - - # ======================================================================== - # CHUNK 1 - # ======================================================================== - nvvm.cp_async_wait_group(0) - cute.arch.sync_threads() - - load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64) - nvvm.cp_async_commit_group() - load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96) - nvvm.cp_async_commit_group() - - store_h_smem_to_gmem(h_sh_chunk0, h_out, tidx, 0) - - pred = cutlass.Float32(0.0) - pred2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair_mul( - h_sh_chunk1[lane_idx, k_base + i].to(cutlass.Float32), - h_sh_chunk1[lane_idx, k_base + i + 1].to(cutlass.Float32), - g_exp, - g_exp, - ) - for i in cutlass.range_constexpr(0, 32, 2): - pred, pred2 = fma_pair( - h_chunk[i], h_chunk[i + 1], k_chunk[i], k_chunk[i + 1], pred, pred2 - ) - pred = pred + pred2 - - pred_sh[lane_idx, warp_idx] = pred - cute.arch.sync_threads() - pred_final = ( - pred_sh[lane_idx, 0] - + pred_sh[lane_idx, 1] - + pred_sh[lane_idx, 2] - + pred_sh[lane_idx, 3] + r_hb6 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - v_val = (v_sh[32 + lane_idx] - pred_final) * beta - - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair( - k_chunk[i], k_chunk[i + 1], v_val, v_val, h_chunk[i], h_chunk[i + 1] - ) - - for i in cutlass.range_constexpr(32): - qk_temp[i] = q_sh[k_base + i] - - out = cutlass.Float32(0.0) - out2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - out, out2 = fma_pair( - h_chunk[i], h_chunk[i + 1], qk_temp[i], qk_temp[i + 1], out, out2 - ) - out = out + out2 - - out_sh[lane_idx, warp_idx] = out - cute.arch.sync_threads() - out_final = ( - out_sh[lane_idx, 0] - + out_sh[lane_idx, 1] - + out_sh[lane_idx, 2] - + out_sh[lane_idx, 3] + r_hb7 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - write_h_chunk_to_smem(h_chunk, h_sh_chunk1, lane_idx, k_base) - if warp_idx == 0: - o_head[32 + lane_idx] = out_final.to(cutlass.BFloat16) - - # ======================================================================== - # CHUNK 2 - # ======================================================================== - nvvm.cp_async_wait_group(1) - cute.arch.sync_threads() - - store_h_smem_to_gmem(h_sh_chunk1, h_out, tidx, 32) - - pred = cutlass.Float32(0.0) - pred2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair_mul( - h_sh_chunk2[lane_idx, k_base + i].to(cutlass.Float32), - h_sh_chunk2[lane_idx, k_base + i + 1].to(cutlass.Float32), - g_exp, - g_exp, - ) - for i in cutlass.range_constexpr(0, 32, 2): - pred, pred2 = fma_pair( - h_chunk[i], h_chunk[i + 1], k_chunk[i], k_chunk[i + 1], pred, pred2 - ) - pred = pred + pred2 - - pred_sh[lane_idx, warp_idx] = pred - cute.arch.sync_threads() - pred_final = ( - pred_sh[lane_idx, 0] - + pred_sh[lane_idx, 1] - + pred_sh[lane_idx, 2] - + pred_sh[lane_idx, 3] + # BF16 register tensors for vectorized loading q, k + r_q_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - v_val = (v_sh[64 + lane_idx] - pred_final) * beta - - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair( - k_chunk[i], k_chunk[i + 1], v_val, v_val, h_chunk[i], h_chunk[i + 1] - ) - - for i in cutlass.range_constexpr(32): - qk_temp[i] = q_sh[k_base + i] - - out = cutlass.Float32(0.0) - out2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - out, out2 = fma_pair( - h_chunk[i], h_chunk[i + 1], qk_temp[i], qk_temp[i + 1], out, out2 - ) - out = out + out2 - - out_sh[lane_idx, warp_idx] = out - cute.arch.sync_threads() - out_final = ( - out_sh[lane_idx, 0] - + out_sh[lane_idx, 1] - + out_sh[lane_idx, 2] - + out_sh[lane_idx, 3] + r_k_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - write_h_chunk_to_smem(h_chunk, h_sh_chunk2, lane_idx, k_base) - if warp_idx == 0: - o_head[64 + lane_idx] = out_final.to(cutlass.BFloat16) - - # ======================================================================== - # CHUNK 3 - # ======================================================================== - nvvm.cp_async_wait_group(0) - cute.arch.sync_threads() - - store_h_smem_to_gmem(h_sh_chunk2, h_out, tidx, 64) - - pred = cutlass.Float32(0.0) - pred2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair_mul( - h_sh_chunk3[lane_idx, k_base + i].to(cutlass.Float32), - h_sh_chunk3[lane_idx, k_base + i + 1].to(cutlass.Float32), - g_exp, - g_exp, - ) - for i in cutlass.range_constexpr(0, 32, 2): - pred, pred2 = fma_pair( - h_chunk[i], h_chunk[i + 1], k_chunk[i], k_chunk[i + 1], pred, pred2 - ) - pred = pred + pred2 - - pred_sh[lane_idx, warp_idx] = pred - cute.arch.sync_threads() - pred_final = ( - pred_sh[lane_idx, 0] - + pred_sh[lane_idx, 1] - + pred_sh[lane_idx, 2] - + pred_sh[lane_idx, 3] + # BF16 register tensors for vectorized V load and output store (8 elements) + r_v_bf16_vec = cute.make_rmem_tensor( + cute.make_layout((ILP_ROWS,), stride=(1,)), cutlass.BFloat16 ) - - v_val = (v_sh[96 + lane_idx] - pred_final) * beta - - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair( - k_chunk[i], k_chunk[i + 1], v_val, v_val, h_chunk[i], h_chunk[i + 1] - ) - - for i in cutlass.range_constexpr(32): - qk_temp[i] = q_sh[k_base + i] - - out = cutlass.Float32(0.0) - out2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - out, out2 = fma_pair( - h_chunk[i], h_chunk[i + 1], qk_temp[i], qk_temp[i + 1], out, out2 - ) - out = out + out2 - - out_sh[lane_idx, warp_idx] = out - cute.arch.sync_threads() - out_final = ( - out_sh[lane_idx, 0] - + out_sh[lane_idx, 1] - + out_sh[lane_idx, 2] - + out_sh[lane_idx, 3] + r_o_bf16_vec = cute.make_rmem_tensor( + cute.make_layout((ILP_ROWS,), stride=(1,)), cutlass.BFloat16 ) - write_h_chunk_to_smem(h_chunk, h_sh_chunk3, lane_idx, k_base) - if warp_idx == 0: - o_head[96 + lane_idx] = out_final.to(cutlass.BFloat16) - - cute.arch.sync_threads() - store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96) + # Compute gate values: only lane 0 computes, then broadcast + r_a_val = cutlass.Float32(a[i_n, i_t, i_hv]) + r_b_val = cutlass.Float32(b[i_n, i_t, i_hv]) + + r_g = 0.0 + r_beta = 0.0 + if lane_id == 0: + x = r_a_val + r_dt_bias + beta_x = softplus_beta * x + softplus_x = 0.0 + if beta_x <= softplus_threshold: + exp_beta_x = cute.exp(beta_x, fastmath=True) + log_input = cutlass.Float32(1.0 + exp_beta_x) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) + softplus_x = cutlass.Float32( + (cutlass.Float32(1.0) / softplus_beta) * log_result + ) + else: + softplus_x = x + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b_val, fastmath=True)) + r_g = cute.exp(r_g_value, fastmath=True) + + r_g = cute.arch.shuffle_sync(r_g, 0) + r_beta = cute.arch.shuffle_sync(r_beta, 0) + + # Load q, k as BF16, convert to FP32 + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + # L2 normalization of Q and K + if use_qk_l2norm: + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + sum_q = cute.arch.warp_reduction_sum(sum_q) + sum_k = cute.arch.warp_reduction_sum(sum_k) + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k + + # Apply scale to Q + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale + + # =================================================================== + # Main loop: process V rows with 8-row ILP + # =================================================================== + flat_state_idx = i_n * HV + i_hv + rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups + eighth_rows: cutlass.Constexpr[int] = rows_per_group // ILP_ROWS + + for row_oct in cutlass.range_constexpr(eighth_rows): + v_base = i_v * tile_v + group_idx * rows_per_group + row_oct * ILP_ROWS + v0 = v_base + v1 = v_base + 1 + v2 = v_base + 2 + v3 = v_base + 3 + v4 = v_base + 4 + v5 = v_base + 5 + v6 = v_base + 6 + v7 = v_base + 7 + + # Always true when tile_v=128, V=128, 4 groups * 8 ILP_ROWS * 4 iters = 128 + if True: + # Load h for ALL 8 V-rows: GMEM BF16 -> BF16 regs (vectorized) -> FP32 regs + ht0 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v0, lane_in_group) + ) + ht1 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v1, lane_in_group) + ) + ht2 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v2, lane_in_group) + ) + ht3 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v3, lane_in_group) + ) + ht4 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v4, lane_in_group) + ) + ht5 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v5, lane_in_group) + ) + ht6 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v6, lane_in_group) + ) + ht7 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v7, lane_in_group) + ) + # Vectorized BF16 loads (64-bit = 4 BF16 elements per load) + cute.autovec_copy(ht0, r_hb0) + cute.autovec_copy(ht1, r_hb1) + cute.autovec_copy(ht2, r_hb2) + cute.autovec_copy(ht3, r_hb3) + cute.autovec_copy(ht4, r_hb4) + cute.autovec_copy(ht5, r_hb5) + cute.autovec_copy(ht6, r_hb6) + cute.autovec_copy(ht7, r_hb7) + + # Convert BF16 -> FP32, apply decay, AND compute dot products h@k in single pass + # Using fma_packed_f32x2 for paired FMA operations + s0 = 0.0 + s1 = 0.0 + s2 = 0.0 + s3 = 0.0 + s4 = 0.0 + s5 = 0.0 + s6 = 0.0 + s7 = 0.0 + s0b = 0.0 + s1b = 0.0 + s2b = 0.0 + s3b = 0.0 + s4b = 0.0 + s5b = 0.0 + s6b = 0.0 + s7b = 0.0 + for i in cutlass.range_constexpr(0, vec_size, 2): + # Convert + decay for pairs of elements + if cutlass.const_expr(use_packed_fma): + r_h[0, i], r_h[0, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb0[i]), + cutlass.Float32(r_hb0[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[1, i], r_h[1, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb1[i]), + cutlass.Float32(r_hb1[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[2, i], r_h[2, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb2[i]), + cutlass.Float32(r_hb2[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[3, i], r_h[3, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb3[i]), + cutlass.Float32(r_hb3[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[4, i], r_h[4, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb4[i]), + cutlass.Float32(r_hb4[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[5, i], r_h[5, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb5[i]), + cutlass.Float32(r_hb5[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[6, i], r_h[6, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb6[i]), + cutlass.Float32(r_hb6[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[7, i], r_h[7, i + 1] = cute.arch.fma_packed_f32x2( + src_a=( + cutlass.Float32(r_hb7[i]), + cutlass.Float32(r_hb7[i + 1]), + ), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + else: + r_h[0, i], r_h[0, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb0[i]), + cutlass.Float32(r_hb0[i + 1]), + r_g, + r_g, + ) + r_h[1, i], r_h[1, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb1[i]), + cutlass.Float32(r_hb1[i + 1]), + r_g, + r_g, + ) + r_h[2, i], r_h[2, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb2[i]), + cutlass.Float32(r_hb2[i + 1]), + r_g, + r_g, + ) + r_h[3, i], r_h[3, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb3[i]), + cutlass.Float32(r_hb3[i + 1]), + r_g, + r_g, + ) + r_h[4, i], r_h[4, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb4[i]), + cutlass.Float32(r_hb4[i + 1]), + r_g, + r_g, + ) + r_h[5, i], r_h[5, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb5[i]), + cutlass.Float32(r_hb5[i + 1]), + r_g, + r_g, + ) + r_h[6, i], r_h[6, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb6[i]), + cutlass.Float32(r_hb6[i + 1]), + r_g, + r_g, + ) + r_h[7, i], r_h[7, i + 1] = fma_pair_mul( + cutlass.Float32(r_hb7[i]), + cutlass.Float32(r_hb7[i + 1]), + r_g, + r_g, + ) + # Dot product h@k using paired FMA + if cutlass.const_expr(use_packed_fma): + s0, s0b = cute.arch.fma_packed_f32x2( + src_a=(r_h[0, i], r_h[0, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s0, s0b), + ) + s1, s1b = cute.arch.fma_packed_f32x2( + src_a=(r_h[1, i], r_h[1, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s1, s1b), + ) + s2, s2b = cute.arch.fma_packed_f32x2( + src_a=(r_h[2, i], r_h[2, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s2, s2b), + ) + s3, s3b = cute.arch.fma_packed_f32x2( + src_a=(r_h[3, i], r_h[3, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s3, s3b), + ) + s4, s4b = cute.arch.fma_packed_f32x2( + src_a=(r_h[4, i], r_h[4, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s4, s4b), + ) + s5, s5b = cute.arch.fma_packed_f32x2( + src_a=(r_h[5, i], r_h[5, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s5, s5b), + ) + s6, s6b = cute.arch.fma_packed_f32x2( + src_a=(r_h[6, i], r_h[6, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s6, s6b), + ) + s7, s7b = cute.arch.fma_packed_f32x2( + src_a=(r_h[7, i], r_h[7, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s7, s7b), + ) + else: + s0, s0b = fma_pair( + r_h[0, i], r_h[0, i + 1], r_k[i], r_k[i + 1], s0, s0b + ) + s1, s1b = fma_pair( + r_h[1, i], r_h[1, i + 1], r_k[i], r_k[i + 1], s1, s1b + ) + s2, s2b = fma_pair( + r_h[2, i], r_h[2, i + 1], r_k[i], r_k[i + 1], s2, s2b + ) + s3, s3b = fma_pair( + r_h[3, i], r_h[3, i + 1], r_k[i], r_k[i + 1], s3, s3b + ) + s4, s4b = fma_pair( + r_h[4, i], r_h[4, i + 1], r_k[i], r_k[i + 1], s4, s4b + ) + s5, s5b = fma_pair( + r_h[5, i], r_h[5, i + 1], r_k[i], r_k[i + 1], s5, s5b + ) + s6, s6b = fma_pair( + r_h[6, i], r_h[6, i + 1], r_k[i], r_k[i + 1], s6, s6b + ) + s7, s7b = fma_pair( + r_h[7, i], r_h[7, i + 1], r_k[i], r_k[i + 1], s7, s7b + ) + # Combine paired accumulators + s0 = s0 + s0b + s1 = s1 + s1b + s2 = s2 + s2b + s3 = s3 + s3b + s4 = s4 + s4b + s5 = s5 + s5b + s6 = s6 + s6b + s7 = s7 + s7b + + # Interleaved butterfly reduction for all 8 s-values (better ILP than sequential warp_reduction_sum) + for offset in [16, 8, 4, 2, 1]: + s0 += cute.arch.shuffle_sync_bfly( + s0, offset=offset, mask=-1, mask_and_clamp=31 + ) + s1 += cute.arch.shuffle_sync_bfly( + s1, offset=offset, mask=-1, mask_and_clamp=31 + ) + s2 += cute.arch.shuffle_sync_bfly( + s2, offset=offset, mask=-1, mask_and_clamp=31 + ) + s3 += cute.arch.shuffle_sync_bfly( + s3, offset=offset, mask=-1, mask_and_clamp=31 + ) + s4 += cute.arch.shuffle_sync_bfly( + s4, offset=offset, mask=-1, mask_and_clamp=31 + ) + s5 += cute.arch.shuffle_sync_bfly( + s5, offset=offset, mask=-1, mask_and_clamp=31 + ) + s6 += cute.arch.shuffle_sync_bfly( + s6, offset=offset, mask=-1, mask_and_clamp=31 + ) + s7 += cute.arch.shuffle_sync_bfly( + s7, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Step 3: Delta rule update - vectorized V load (8 consecutive BF16 elements) + vt_slice = cute.local_tile( + v, (1, 1, 1, ILP_ROWS), (i_n, i_t, i_hv, v_base // ILP_ROWS) + ) + cute.autovec_copy(vt_slice, r_v_bf16_vec) + vn0 = (cutlass.Float32(r_v_bf16_vec[0]) - s0) * r_beta + vn1 = (cutlass.Float32(r_v_bf16_vec[1]) - s1) * r_beta + vn2 = (cutlass.Float32(r_v_bf16_vec[2]) - s2) * r_beta + vn3 = (cutlass.Float32(r_v_bf16_vec[3]) - s3) * r_beta + vn4 = (cutlass.Float32(r_v_bf16_vec[4]) - s4) * r_beta + vn5 = (cutlass.Float32(r_v_bf16_vec[5]) - s5) * r_beta + vn6 = (cutlass.Float32(r_v_bf16_vec[6]) - s6) * r_beta + vn7 = (cutlass.Float32(r_v_bf16_vec[7]) - s7) * r_beta + + # Step 4: Rank-1 update + output dot products h@q using fma_packed_f32x2 + o0 = 0.0 + o1 = 0.0 + o2 = 0.0 + o3 = 0.0 + o4 = 0.0 + o5 = 0.0 + o6 = 0.0 + o7 = 0.0 + o0b = 0.0 + o1b = 0.0 + o2b = 0.0 + o3b = 0.0 + o4b = 0.0 + o5b = 0.0 + o6b = 0.0 + o7b = 0.0 + for i in cutlass.range_constexpr(0, vec_size, 2): + # Rank-1 update: h += k * vn (paired FMA) + if cutlass.const_expr(use_packed_fma): + r_h[0, i], r_h[0, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn0, vn0), + src_c=(r_h[0, i], r_h[0, i + 1]), + ) + r_h[1, i], r_h[1, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn1, vn1), + src_c=(r_h[1, i], r_h[1, i + 1]), + ) + r_h[2, i], r_h[2, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn2, vn2), + src_c=(r_h[2, i], r_h[2, i + 1]), + ) + r_h[3, i], r_h[3, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn3, vn3), + src_c=(r_h[3, i], r_h[3, i + 1]), + ) + r_h[4, i], r_h[4, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn4, vn4), + src_c=(r_h[4, i], r_h[4, i + 1]), + ) + r_h[5, i], r_h[5, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn5, vn5), + src_c=(r_h[5, i], r_h[5, i + 1]), + ) + r_h[6, i], r_h[6, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn6, vn6), + src_c=(r_h[6, i], r_h[6, i + 1]), + ) + r_h[7, i], r_h[7, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn7, vn7), + src_c=(r_h[7, i], r_h[7, i + 1]), + ) + else: + r_h[0, i], r_h[0, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn0, vn0, r_h[0, i], r_h[0, i + 1] + ) + r_h[1, i], r_h[1, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn1, vn1, r_h[1, i], r_h[1, i + 1] + ) + r_h[2, i], r_h[2, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn2, vn2, r_h[2, i], r_h[2, i + 1] + ) + r_h[3, i], r_h[3, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn3, vn3, r_h[3, i], r_h[3, i + 1] + ) + r_h[4, i], r_h[4, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn4, vn4, r_h[4, i], r_h[4, i + 1] + ) + r_h[5, i], r_h[5, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn5, vn5, r_h[5, i], r_h[5, i + 1] + ) + r_h[6, i], r_h[6, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn6, vn6, r_h[6, i], r_h[6, i + 1] + ) + r_h[7, i], r_h[7, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn7, vn7, r_h[7, i], r_h[7, i + 1] + ) + # Output dot product: o += h * q (paired FMA) + if cutlass.const_expr(use_packed_fma): + o0, o0b = cute.arch.fma_packed_f32x2( + src_a=(r_h[0, i], r_h[0, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o0, o0b), + ) + o1, o1b = cute.arch.fma_packed_f32x2( + src_a=(r_h[1, i], r_h[1, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o1, o1b), + ) + o2, o2b = cute.arch.fma_packed_f32x2( + src_a=(r_h[2, i], r_h[2, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o2, o2b), + ) + o3, o3b = cute.arch.fma_packed_f32x2( + src_a=(r_h[3, i], r_h[3, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o3, o3b), + ) + o4, o4b = cute.arch.fma_packed_f32x2( + src_a=(r_h[4, i], r_h[4, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o4, o4b), + ) + o5, o5b = cute.arch.fma_packed_f32x2( + src_a=(r_h[5, i], r_h[5, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o5, o5b), + ) + o6, o6b = cute.arch.fma_packed_f32x2( + src_a=(r_h[6, i], r_h[6, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o6, o6b), + ) + o7, o7b = cute.arch.fma_packed_f32x2( + src_a=(r_h[7, i], r_h[7, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o7, o7b), + ) + else: + o0, o0b = fma_pair( + r_h[0, i], r_h[0, i + 1], r_q[i], r_q[i + 1], o0, o0b + ) + o1, o1b = fma_pair( + r_h[1, i], r_h[1, i + 1], r_q[i], r_q[i + 1], o1, o1b + ) + o2, o2b = fma_pair( + r_h[2, i], r_h[2, i + 1], r_q[i], r_q[i + 1], o2, o2b + ) + o3, o3b = fma_pair( + r_h[3, i], r_h[3, i + 1], r_q[i], r_q[i + 1], o3, o3b + ) + o4, o4b = fma_pair( + r_h[4, i], r_h[4, i + 1], r_q[i], r_q[i + 1], o4, o4b + ) + o5, o5b = fma_pair( + r_h[5, i], r_h[5, i + 1], r_q[i], r_q[i + 1], o5, o5b + ) + o6, o6b = fma_pair( + r_h[6, i], r_h[6, i + 1], r_q[i], r_q[i + 1], o6, o6b + ) + o7, o7b = fma_pair( + r_h[7, i], r_h[7, i + 1], r_q[i], r_q[i + 1], o7, o7b + ) + # Combine paired accumulators + o0 = o0 + o0b + o1 = o1 + o1b + o2 = o2 + o2b + o3 = o3 + o3b + o4 = o4 + o4b + o5 = o5 + o5b + o6 = o6 + o6b + o7 = o7 + o7b + + # Interleaved butterfly reduction for all 8 o-values (better ILP than sequential warp_reduction_sum) + for offset in [16, 8, 4, 2, 1]: + o0 += cute.arch.shuffle_sync_bfly( + o0, offset=offset, mask=-1, mask_and_clamp=31 + ) + o1 += cute.arch.shuffle_sync_bfly( + o1, offset=offset, mask=-1, mask_and_clamp=31 + ) + o2 += cute.arch.shuffle_sync_bfly( + o2, offset=offset, mask=-1, mask_and_clamp=31 + ) + o3 += cute.arch.shuffle_sync_bfly( + o3, offset=offset, mask=-1, mask_and_clamp=31 + ) + o4 += cute.arch.shuffle_sync_bfly( + o4, offset=offset, mask=-1, mask_and_clamp=31 + ) + o5 += cute.arch.shuffle_sync_bfly( + o5, offset=offset, mask=-1, mask_and_clamp=31 + ) + o6 += cute.arch.shuffle_sync_bfly( + o6, offset=offset, mask=-1, mask_and_clamp=31 + ) + o7 += cute.arch.shuffle_sync_bfly( + o7, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Write output: pack into BF16 reg tensor and vectorized store + if lane_in_group == 0: + r_o_bf16_vec[0] = cutlass.BFloat16(o0) + r_o_bf16_vec[1] = cutlass.BFloat16(o1) + r_o_bf16_vec[2] = cutlass.BFloat16(o2) + r_o_bf16_vec[3] = cutlass.BFloat16(o3) + r_o_bf16_vec[4] = cutlass.BFloat16(o4) + r_o_bf16_vec[5] = cutlass.BFloat16(o5) + r_o_bf16_vec[6] = cutlass.BFloat16(o6) + r_o_bf16_vec[7] = cutlass.BFloat16(o7) + ot_slice = cute.local_tile( + o, (1, 1, 1, ILP_ROWS), (i_n, i_t, i_hv, v_base // ILP_ROWS) + ) + cute.autovec_copy(r_o_bf16_vec, ot_slice) + + # Write updated H back to GMEM: FP32 regs -> BF16 regs -> GMEM BF16 (vectorized) + for i in cutlass.range_constexpr(vec_size): + r_hb0[i] = cutlass.BFloat16(r_h[0, i]) + r_hb1[i] = cutlass.BFloat16(r_h[1, i]) + r_hb2[i] = cutlass.BFloat16(r_h[2, i]) + r_hb3[i] = cutlass.BFloat16(r_h[3, i]) + r_hb4[i] = cutlass.BFloat16(r_h[4, i]) + r_hb5[i] = cutlass.BFloat16(r_h[5, i]) + r_hb6[i] = cutlass.BFloat16(r_h[6, i]) + r_hb7[i] = cutlass.BFloat16(r_h[7, i]) + cute.autovec_copy(r_hb0, ht0) + cute.autovec_copy(r_hb1, ht1) + cute.autovec_copy(r_hb2, ht2) + cute.autovec_copy(r_hb3, ht3) + cute.autovec_copy(r_hb4, ht4) + cute.autovec_copy(r_hb5, ht5) + cute.autovec_copy(r_hb6, ht6) + cute.autovec_copy(r_hb7, ht7) # ============================================================================== -# UNIFIED SEQLEN=2/3/4 MAIN KERNEL +# KERNEL: MTP (Multiple Token Processing) with BF16 state # ============================================================================== +# Architecture (adapted from gdn_verify_kernel_mtp_original in gdn_decode.py): +# - Grid: (B * HV * num_v_tiles, 1, 1) - each block handles one TILE_V chunk +# - 128 threads = 4 groups of 32 threads (full warps) +# - Each group processes tile_v/4 V-rows +# - H loaded as BF16 from GMEM, computed in FP32, stored back as BF16 +# - Processes T tokens sequentially, keeping h in FP32 registers +# - Optional: cache intermediate states, disable state update + +MTP_TILE_K = 128 +MTP_NUM_THREADS = 128 +MTP_VEC_SIZE = 4 # 32 threads per group × 4 = 128 K elements +MTP_ILP_ROWS = 8 # Process 8 V-rows simultaneously per group iteration @cute.kernel -def gated_delta_rule_decode_kernel_seqlen234_unified( - gQ: cute.Tensor, # [B, T=2/3/4, H, K=128] - gK: cute.Tensor, # [B, T=2/3/4, H, K=128] - gV: cute.Tensor, # [B, T=2/3/4, HV, V=128] - ga: cute.Tensor, # [B, T=2/3/4, HV] - gb: cute.Tensor, # [B, T=2/3/4, HV] - gA_log: cute.Tensor, # [HV] - gdt_bias: cute.Tensor, # [HV] - gH: cute.Tensor, # [pool, HV, V=128, K=128] - K-fast layout - gH_slot_indices: cute.Tensor, # [B] indices mapping batch -> pool slot - gO: cute.Tensor, # [B, T=2/3/4, HV, V=128] - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, - NUM_TOKENS: cutlass.Constexpr[int], # 2, 3, or 4 +def gdn_decode_bf16state_mtp_kernel( + h0_source: cute.Tensor, # [pool_size * HV, V, K] as BF16 + intermediate_states: cute.Tensor, # [pool_size * T * HV, V, K] as BF16 (or dummy) + vec_size: cutlass.Constexpr[int], + num_v_tiles: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], + A_log: cute.Tensor, # [HV] + a: cute.Tensor, # [B, T, HV] + dt_bias: cute.Tensor, # [HV] + q: cute.Tensor, # [B, T, H, K] + k: cute.Tensor, # [B, T, H, K] + v: cute.Tensor, # [B, T, HV, V] + b: cute.Tensor, # [B, T, HV] + o: cute.Tensor, # [B, T, HV, V] - output + h0_indices: cute.Tensor, # [B] - initial state indices + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + B: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + cache_intermediate_states: cutlass.Constexpr[bool], + use_packed_fma: cutlass.Constexpr[bool], ): """ - Unified kernel for Seqlen=2, Seqlen=3 and Seqlen=4 with compile-time specialization. - - Uses cutlass.Constexpr[int] NUM_TOKENS parameter to eliminate dead code paths: - - NUM_TOKENS=2: 4-slot reduce_sh, 2 Q/K/V buffers, 2 gates - - NUM_TOKENS=3: 6-slot reduce_sh, 3 Q/K/V buffers, 3 gates - - NUM_TOKENS=4: 8-slot reduce_sh, 4 Q/K/V buffers, 4 gates + ILP-optimized MTP kernel for BF16 state: processes T tokens sequentially. + Each block handles one tile_v chunk of V rows. + H is loaded as BF16, computed in FP32, stored back as BF16. + Uses 8-row ILP with fma_packed_f32x2 (Blackwell) / scalar FMA (Hopper) with compile-time dispatch. """ tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - HV = cutlass.Int32(gV.shape[2]) - H = cutlass.Int32(gQ.shape[2]) - - batch_idx = bidx // HV - value_head_idx = bidx % HV - query_head_idx = value_head_idx // (HV // H) - pool_batch_idx = gH_slot_indices[batch_idx] - if pool_batch_idx < 0: - pool_batch_idx = cutlass.Int32(0) - - warp_idx = tidx // 32 - lane_idx = tidx % 32 - k_base = warp_idx * 32 - - smem = utils.SmemAllocator() + lane_id = tidx % 32 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # 4 groups (= 4 warps), each full warp of 32 threads + threads_per_group: cutlass.Constexpr[int] = 32 # noqa: F841 + num_groups: cutlass.Constexpr[int] = 4 + group_idx = warp_idx + lane_in_group = lane_id + + batch_idx, _, _ = cute.arch.block_idx() + + # Decode block index: (i_n, i_hv, i_v) from batch_idx + i_v = batch_idx % num_v_tiles + tmp = batch_idx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV + i_h = i_hv // (HV // H) + + # Get initial state index for this batch + cache_idx = h0_indices[i_n] + + # Load A_log and dt_bias once + r_A_log = cutlass.Float32(A_log[i_hv]) + r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + + # For T>1: shared SMEM for q/k (one copy, all warps read) + # Precomputed in parallel: warp i handles token i (barrier before inner loop) + # For T>2: also cache g/beta in SMEM (saves redundant exp/log across row_oct iterations) + # For T=1: no SMEM needed (inline compute is faster) + if cutlass.const_expr(T > 1): + smem = cutlass.utils.SmemAllocator() + sQ = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16 + ) + sK = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16 + ) + # Always allocate sGB (SMEM variable must exist for all T>1 paths) + sGB = smem.allocate_tensor( + cutlass.Float32, cute.make_layout((T, 2), stride=(2, 1)), 16 + ) - # SMEM Allocation - H chunks - h_sh_chunk0 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + # Register arrays for computation - ILP=8 rows of vec_size=4 each + r_q = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) - h_sh_chunk1 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + r_k = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) - h_sh_chunk2 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + r_h = cute.make_rmem_tensor( + cute.make_layout((MTP_ILP_ROWS, vec_size), stride=(vec_size, 1)), + cutlass.Float32, ) - h_sh_chunk3 = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + # BF16 register tensors for vectorized loading q, k + r_q_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - # Q/K buffers for tokens 0 and 1 (always needed for T>=2) - q_sh0 = smem.allocate_tensor(cutlass.Float32, 128) - k_sh0 = smem.allocate_tensor(cutlass.Float32, 128) - q_sh1 = smem.allocate_tensor(cutlass.Float32, 128) - k_sh1 = smem.allocate_tensor(cutlass.Float32, 128) - - # Q/K buffers for token 2 (only for NUM_TOKENS >= 3) - q_sh2 = smem.allocate_tensor(cutlass.Float32, 128) - k_sh2 = smem.allocate_tensor(cutlass.Float32, 128) - - # Q/K buffers for token 3 (only for NUM_TOKENS=4) - q_sh3 = smem.allocate_tensor(cutlass.Float32, 128) - k_sh3 = smem.allocate_tensor(cutlass.Float32, 128) - - # V buffers - v_sh0 = smem.allocate_tensor(cutlass.Float32, 128) - v_sh1 = smem.allocate_tensor(cutlass.Float32, 128) - v_sh2 = smem.allocate_tensor(cutlass.Float32, 128) - v_sh3 = smem.allocate_tensor(cutlass.Float32, 128) - - # Bank-conflict-free reduce_sh: [slot, lane_idx, warp_idx] - reduce_sh = smem.allocate_tensor( - cutlass.Float32, cute.make_layout((8, 32, 4), stride=(128, 4, 1)) + r_k_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - # Register allocation - h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) - kq_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) - - # Gate computation - always compute gates 0, 1 (for T>=2) - A_log_val = gA_log[value_head_idx] - dt_bias_val = gdt_bias[value_head_idx] - - alpha0 = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) - beta_raw0 = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) - g_exp0, beta0 = compute_single_gate( - alpha0, beta_raw0, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + # 8 separate BF16 register tensors for vectorized H loading (autovec_copy) + r_hb0 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 ) - - alpha1 = ga[(batch_idx, 1, value_head_idx)].to(cutlass.Float32) - beta_raw1 = gb[(batch_idx, 1, value_head_idx)].to(cutlass.Float32) - g_exp1, beta1 = compute_single_gate( - alpha1, beta_raw1, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + r_hb1 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_hb2 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_hb3 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_hb4 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_hb5 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_hb6 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_hb7 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + # BF16 register tensors for vectorized V load and output store (8 elements) + r_v_bf16_vec = cute.make_rmem_tensor( + cute.make_layout((MTP_ILP_ROWS,), stride=(1,)), cutlass.BFloat16 + ) + r_o_bf16_vec = cute.make_rmem_tensor( + cute.make_layout((MTP_ILP_ROWS,), stride=(1,)), cutlass.BFloat16 ) - # Gate 2 - only for NUM_TOKENS >= 3 - g_exp2 = cutlass.Float32(0.0) - beta2 = cutlass.Float32(0.0) - if NUM_TOKENS >= 3: - alpha2 = ga[(batch_idx, 2, value_head_idx)].to(cutlass.Float32) - beta_raw2 = gb[(batch_idx, 2, value_head_idx)].to(cutlass.Float32) - g_exp2, beta2 = compute_single_gate( - alpha2, beta_raw2, dt_bias_val, A_log_val, softplus_beta, softplus_threshold - ) - - # Gate 3 - only for NUM_TOKENS = 4 - g_exp3 = cutlass.Float32(0.0) - beta3 = cutlass.Float32(0.0) - if NUM_TOKENS == 4: - alpha3 = ga[(batch_idx, 3, value_head_idx)].to(cutlass.Float32) - beta_raw3 = gb[(batch_idx, 3, value_head_idx)].to(cutlass.Float32) - g_exp3, beta3 = compute_single_gate( - alpha3, beta_raw3, dt_bias_val, A_log_val, softplus_beta, softplus_threshold - ) - - # Upfront H loading - h_global = gH[(pool_batch_idx, value_head_idx, None, None)] - load_h_chunk_async(h_sh_chunk0, h_global, tidx, 0) - nvvm.cp_async_commit_group() - load_h_chunk_async(h_sh_chunk1, h_global, tidx, 32) - nvvm.cp_async_commit_group() - load_h_chunk_async(h_sh_chunk2, h_global, tidx, 64) - nvvm.cp_async_commit_group() - load_h_chunk_async(h_sh_chunk3, h_global, tidx, 96) - nvvm.cp_async_commit_group() - - # Q/K normalization - tokens 0, 1 always - q_head0 = gQ[(batch_idx, 0, query_head_idx, None)] - k_head0 = gK[(batch_idx, 0, query_head_idx, None)] - q_head1 = gQ[(batch_idx, 1, query_head_idx, None)] - k_head1 = gK[(batch_idx, 1, query_head_idx, None)] - - if warp_idx == 0: - normalize_and_store_qk_to_smem( - q_head0, k_head0, q_sh0, k_sh0, lane_idx, scale, eps - ) - if warp_idx == 1: - normalize_and_store_qk_to_smem( - q_head1, k_head1, q_sh1, k_sh1, lane_idx, scale, eps - ) - - # Token 2 Q/K normalization - only for NUM_TOKENS >= 3 - if NUM_TOKENS >= 3: - q_head2 = gQ[(batch_idx, 2, query_head_idx, None)] - k_head2 = gK[(batch_idx, 2, query_head_idx, None)] - if warp_idx == 2: - normalize_and_store_qk_to_smem( - q_head2, k_head2, q_sh2, k_sh2, lane_idx, scale, eps + # Redirect padding entries (cache_idx < 0) to null buffer (slot 0) + if cache_idx < 0: + cache_idx = cutlass.Int32(0) + + # Process all batch entries (padding slots redirected to slot 0 above) + if cache_idx >= 0: + k_start = lane_in_group * vec_size + + # For T>1: parallel precompute q, k into shared SMEM + # With 4 warps, each pass precomputes up to 4 tokens in parallel. + # For T<=4: 1 pass. For T=5..8: 2 passes. General: ceil(T/4) passes. + if cutlass.const_expr(T > 1): + num_precompute_passes: cutlass.Constexpr[int] = ( + T + num_groups - 1 + ) // num_groups + for pass_idx in cutlass.range_constexpr(num_precompute_passes): + i_t_pre = pass_idx * num_groups + group_idx + if i_t_pre < T: + q_tile_pre = cute.local_tile( + q, (1, 1, 1, vec_size), (i_n, i_t_pre, i_h, lane_in_group) + ) + k_tile_pre = cute.local_tile( + k, (1, 1, 1, vec_size), (i_n, i_t_pre, i_h, lane_in_group) + ) + cute.autovec_copy(q_tile_pre, r_q_bf16) + cute.autovec_copy(k_tile_pre, r_k_bf16) + + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + if cutlass.const_expr(use_qk_l2norm): + sum_q = 0.0 + sum_k = 0.0 + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) + inv_norm_q_scaled = ( + cute.rsqrt(sum_q + 1e-6, fastmath=True) * scale + ) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q_scaled + r_k[i] = r_k[i] * inv_norm_k + else: + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale + + # Write to shared SMEM (all active warps write different token slots) + for i in cutlass.range_constexpr(vec_size): + sQ[(i_t_pre, k_start + i)] = r_q[i] + sK[(i_t_pre, k_start + i)] = r_k[i] + + # Precompute g/beta for the assigned token - only for T>2 + if cutlass.const_expr(T > 2): + r_a_pre = cutlass.Float32(a[i_n, i_t_pre, i_hv]) + r_b_pre = cutlass.Float32(b[i_n, i_t_pre, i_hv]) + x_pre = r_a_pre + r_dt_bias + beta_x_pre = softplus_beta * x_pre + exp_beta_x_pre = cute.exp(beta_x_pre, fastmath=True) + softplus_val_pre = ( + cutlass.Float32(1.0) / softplus_beta + ) * cute.log( + cutlass.Float32(1.0) + exp_beta_x_pre, fastmath=True + ) + use_softplus_pre = ( + cutlass.Float32(1.0) + if beta_x_pre <= softplus_threshold + else cutlass.Float32(0.0) + ) + softplus_x_pre = ( + use_softplus_pre * softplus_val_pre + + (cutlass.Float32(1.0) - use_softplus_pre) * x_pre + ) + r_g_value_pre = ( + -cute.exp(r_A_log, fastmath=True) * softplus_x_pre + ) + r_beta_pre = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cute.exp(-r_b_pre, fastmath=True) + ) + r_g_pre = cute.exp(r_g_value_pre, fastmath=True) + if lane_in_group == 0: + sGB[(i_t_pre, 0)] = r_g_pre + sGB[(i_t_pre, 1)] = r_beta_pre + + # Barrier after each pass: all warps must finish writing before next pass reads/writes + cute.arch.barrier() + + # Each group handles tile_v/num_groups V rows, 8 at a time (ILP=8) + flat_state_idx = cache_idx * HV + i_hv + rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups + eighth_rows: cutlass.Constexpr[int] = rows_per_group // MTP_ILP_ROWS + + # Pre-declare loop-carried variables for dynamic loop compatibility (T>1) + sum_q = cutlass.Float32(0.0) + sum_k = cutlass.Float32(0.0) + inv_norm_q_scaled = cutlass.Float32(1.0) + inv_norm_k = cutlass.Float32(1.0) + + # For T>1: don't unroll row_oct loop (reduces code size for better icache) + # For T=1: fully unroll row_oct loop (no code size issue, max performance) + for row_oct in cutlass.range(eighth_rows, unroll=1, unroll_full=(T <= 1)): + v_base = i_v * tile_v + group_idx * rows_per_group + row_oct * MTP_ILP_ROWS + v0 = v_base + v1 = v_base + 1 + v2 = v_base + 2 + v3 = v_base + 3 + v4 = v_base + 4 + v5 = v_base + 5 + v6 = v_base + 6 + v7 = v_base + 7 + + # Load h for ALL 8 V-rows: GMEM BF16 -> BF16 regs (vectorized) -> FP32 regs + ht0 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v0, lane_in_group) ) - - # Token 3 Q/K normalization - only for NUM_TOKENS = 4 - if NUM_TOKENS == 4: - q_head3 = gQ[(batch_idx, 3, query_head_idx, None)] - k_head3 = gK[(batch_idx, 3, query_head_idx, None)] - if warp_idx == 3: - normalize_and_store_qk_to_smem( - q_head3, k_head3, q_sh3, k_sh3, lane_idx, scale, eps + ht1 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v1, lane_in_group) + ) + ht2 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v2, lane_in_group) + ) + ht3 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v3, lane_in_group) + ) + ht4 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v4, lane_in_group) + ) + ht5 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v5, lane_in_group) + ) + ht6 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v6, lane_in_group) + ) + ht7 = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v7, lane_in_group) ) + cute.autovec_copy(ht0, r_hb0) + cute.autovec_copy(ht1, r_hb1) + cute.autovec_copy(ht2, r_hb2) + cute.autovec_copy(ht3, r_hb3) + cute.autovec_copy(ht4, r_hb4) + cute.autovec_copy(ht5, r_hb5) + cute.autovec_copy(ht6, r_hb6) + cute.autovec_copy(ht7, r_hb7) + + # Convert BF16 -> FP32 for all 8 rows + for i in cutlass.range_constexpr(vec_size): + r_h[0, i] = cutlass.Float32(r_hb0[i]) + r_h[1, i] = cutlass.Float32(r_hb1[i]) + r_h[2, i] = cutlass.Float32(r_hb2[i]) + r_h[3, i] = cutlass.Float32(r_hb3[i]) + r_h[4, i] = cutlass.Float32(r_hb4[i]) + r_h[5, i] = cutlass.Float32(r_hb5[i]) + r_h[6, i] = cutlass.Float32(r_hb6[i]) + r_h[7, i] = cutlass.Float32(r_hb7[i]) + + # Process all T time steps with h in FP32 registers + # For T>1: use dynamic timestep loop to reduce code size (saves icache) + # For T=1: fully unroll timestep loop (minimal overhead, no loop counter) + for i_t in cutlass.range(T, unroll=1, unroll_full=(T <= 1)): + # Load q, k, g, beta - conditionally from SMEM or inline + if cutlass.const_expr(T > 1): + # T>1: read q,k from shared SMEM (pre-computed in parallel) + sQ_tile = cute.local_tile(sQ, (1, vec_size), (i_t, lane_in_group)) + sK_tile = cute.local_tile(sK, (1, vec_size), (i_t, lane_in_group)) + cute.autovec_copy(sQ_tile, r_q) + cute.autovec_copy(sK_tile, r_k) + if cutlass.const_expr(T > 2): + # T>2: read pre-computed g, beta from shared SMEM + r_g = sGB[(i_t, 0)] + r_beta = sGB[(i_t, 1)] + else: + # T=2: compute g, beta inline (avoids SMEM read latency) + r_a_val = cutlass.Float32(a[i_n, i_t, i_hv]) + r_b_val = cutlass.Float32(b[i_n, i_t, i_hv]) + x_val = r_a_val + r_dt_bias + beta_x_val = softplus_beta * x_val + exp_beta_x_val = cute.exp(beta_x_val, fastmath=True) + softplus_val_v = ( + cutlass.Float32(1.0) / softplus_beta + ) * cute.log( + cutlass.Float32(1.0) + exp_beta_x_val, fastmath=True + ) + use_softplus_v = ( + cutlass.Float32(1.0) + if beta_x_val <= softplus_threshold + else cutlass.Float32(0.0) + ) + softplus_x_v = ( + use_softplus_v * softplus_val_v + + (cutlass.Float32(1.0) - use_softplus_v) * x_val + ) + r_g_value_v = -cute.exp(r_A_log, fastmath=True) * softplus_x_v + r_beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cute.exp(-r_b_val, fastmath=True) + ) + r_g = cute.exp(r_g_value_v, fastmath=True) + else: + # T=1: compute inline (no SMEM overhead) + q_tile_t = cute.local_tile( + q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group) + ) + k_tile_t = cute.local_tile( + k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group) + ) + cute.autovec_copy(q_tile_t, r_q_bf16) + cute.autovec_copy(k_tile_t, r_k_bf16) + + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + if cutlass.const_expr(use_qk_l2norm): + sum_q = cutlass.Float32(0.0) + sum_k = cutlass.Float32(0.0) + for i in cutlass.range_constexpr(vec_size): + sum_q += r_q[i] * r_q[i] + sum_k += r_k[i] * r_k[i] + for offset in [16, 8, 4, 2, 1]: + sum_q += cute.arch.shuffle_sync_bfly( + sum_q, offset=offset, mask=-1, mask_and_clamp=31 + ) + sum_k += cute.arch.shuffle_sync_bfly( + sum_k, offset=offset, mask=-1, mask_and_clamp=31 + ) + inv_norm_q_scaled = ( + cute.rsqrt(sum_q + 1e-6, fastmath=True) * scale + ) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q_scaled + r_k[i] = r_k[i] * inv_norm_k + else: + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale + + r_a_val = cutlass.Float32(a[i_n, i_t, i_hv]) + r_b_val = cutlass.Float32(b[i_n, i_t, i_hv]) + x_val = r_a_val + r_dt_bias + beta_x_val = softplus_beta * x_val + exp_beta_x_val = cute.exp(beta_x_val, fastmath=True) + softplus_val_v = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + exp_beta_x_val, fastmath=True + ) + use_softplus_v = ( + cutlass.Float32(1.0) + if beta_x_val <= softplus_threshold + else cutlass.Float32(0.0) + ) + softplus_x_v = ( + use_softplus_v * softplus_val_v + + (cutlass.Float32(1.0) - use_softplus_v) * x_val + ) + r_g_value_v = -cute.exp(r_A_log, fastmath=True) * softplus_x_v + r_beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cute.exp(-r_b_val, fastmath=True) + ) + r_g = cute.exp(r_g_value_v, fastmath=True) + + # Fused: decay h, dot product h@k with conditional dispatch + s0 = 0.0 + s1 = 0.0 + s2 = 0.0 + s3 = 0.0 + s4 = 0.0 + s5 = 0.0 + s6 = 0.0 + s7 = 0.0 + s0b = 0.0 + s1b = 0.0 + s2b = 0.0 + s3b = 0.0 + s4b = 0.0 + s5b = 0.0 + s6b = 0.0 + s7b = 0.0 + for i in cutlass.range_constexpr(0, vec_size, 2): + # Convert + decay for pairs of elements + if cutlass.const_expr(use_packed_fma): + r_h[0, i], r_h[0, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[0, i], r_h[0, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[1, i], r_h[1, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[1, i], r_h[1, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[2, i], r_h[2, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[2, i], r_h[2, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[3, i], r_h[3, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[3, i], r_h[3, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[4, i], r_h[4, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[4, i], r_h[4, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[5, i], r_h[5, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[5, i], r_h[5, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[6, i], r_h[6, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[6, i], r_h[6, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + r_h[7, i], r_h[7, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_h[7, i], r_h[7, i + 1]), + src_b=(r_g, r_g), + src_c=(cutlass.Float32(0.0), cutlass.Float32(0.0)), + ) + else: + r_h[0, i], r_h[0, i + 1] = fma_pair_mul( + r_h[0, i], r_h[0, i + 1], r_g, r_g + ) + r_h[1, i], r_h[1, i + 1] = fma_pair_mul( + r_h[1, i], r_h[1, i + 1], r_g, r_g + ) + r_h[2, i], r_h[2, i + 1] = fma_pair_mul( + r_h[2, i], r_h[2, i + 1], r_g, r_g + ) + r_h[3, i], r_h[3, i + 1] = fma_pair_mul( + r_h[3, i], r_h[3, i + 1], r_g, r_g + ) + r_h[4, i], r_h[4, i + 1] = fma_pair_mul( + r_h[4, i], r_h[4, i + 1], r_g, r_g + ) + r_h[5, i], r_h[5, i + 1] = fma_pair_mul( + r_h[5, i], r_h[5, i + 1], r_g, r_g + ) + r_h[6, i], r_h[6, i + 1] = fma_pair_mul( + r_h[6, i], r_h[6, i + 1], r_g, r_g + ) + r_h[7, i], r_h[7, i + 1] = fma_pair_mul( + r_h[7, i], r_h[7, i + 1], r_g, r_g + ) + # Dot product h@k using paired FMA + if cutlass.const_expr(use_packed_fma): + s0, s0b = cute.arch.fma_packed_f32x2( + src_a=(r_h[0, i], r_h[0, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s0, s0b), + ) + s1, s1b = cute.arch.fma_packed_f32x2( + src_a=(r_h[1, i], r_h[1, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s1, s1b), + ) + s2, s2b = cute.arch.fma_packed_f32x2( + src_a=(r_h[2, i], r_h[2, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s2, s2b), + ) + s3, s3b = cute.arch.fma_packed_f32x2( + src_a=(r_h[3, i], r_h[3, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s3, s3b), + ) + s4, s4b = cute.arch.fma_packed_f32x2( + src_a=(r_h[4, i], r_h[4, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s4, s4b), + ) + s5, s5b = cute.arch.fma_packed_f32x2( + src_a=(r_h[5, i], r_h[5, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s5, s5b), + ) + s6, s6b = cute.arch.fma_packed_f32x2( + src_a=(r_h[6, i], r_h[6, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s6, s6b), + ) + s7, s7b = cute.arch.fma_packed_f32x2( + src_a=(r_h[7, i], r_h[7, i + 1]), + src_b=(r_k[i], r_k[i + 1]), + src_c=(s7, s7b), + ) + else: + s0, s0b = fma_pair( + r_h[0, i], r_h[0, i + 1], r_k[i], r_k[i + 1], s0, s0b + ) + s1, s1b = fma_pair( + r_h[1, i], r_h[1, i + 1], r_k[i], r_k[i + 1], s1, s1b + ) + s2, s2b = fma_pair( + r_h[2, i], r_h[2, i + 1], r_k[i], r_k[i + 1], s2, s2b + ) + s3, s3b = fma_pair( + r_h[3, i], r_h[3, i + 1], r_k[i], r_k[i + 1], s3, s3b + ) + s4, s4b = fma_pair( + r_h[4, i], r_h[4, i + 1], r_k[i], r_k[i + 1], s4, s4b + ) + s5, s5b = fma_pair( + r_h[5, i], r_h[5, i + 1], r_k[i], r_k[i + 1], s5, s5b + ) + s6, s6b = fma_pair( + r_h[6, i], r_h[6, i + 1], r_k[i], r_k[i + 1], s6, s6b + ) + s7, s7b = fma_pair( + r_h[7, i], r_h[7, i + 1], r_k[i], r_k[i + 1], s7, s7b + ) + # Combine paired accumulators + s0 = s0 + s0b + s1 = s1 + s1b + s2 = s2 + s2b + s3 = s3 + s3b + s4 = s4 + s4b + s5 = s5 + s5b + s6 = s6 + s6b + s7 = s7 + s7b + + # Interleaved butterfly reduction for 8 s-values + for offset in [16, 8, 4, 2, 1]: + s0 += cute.arch.shuffle_sync_bfly( + s0, offset=offset, mask=-1, mask_and_clamp=31 + ) + s1 += cute.arch.shuffle_sync_bfly( + s1, offset=offset, mask=-1, mask_and_clamp=31 + ) + s2 += cute.arch.shuffle_sync_bfly( + s2, offset=offset, mask=-1, mask_and_clamp=31 + ) + s3 += cute.arch.shuffle_sync_bfly( + s3, offset=offset, mask=-1, mask_and_clamp=31 + ) + s4 += cute.arch.shuffle_sync_bfly( + s4, offset=offset, mask=-1, mask_and_clamp=31 + ) + s5 += cute.arch.shuffle_sync_bfly( + s5, offset=offset, mask=-1, mask_and_clamp=31 + ) + s6 += cute.arch.shuffle_sync_bfly( + s6, offset=offset, mask=-1, mask_and_clamp=31 + ) + s7 += cute.arch.shuffle_sync_bfly( + s7, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Delta rule: v_new = (v - sum_hk) * beta - vectorized V load + vt_slice = cute.local_tile( + v, (1, 1, 1, MTP_ILP_ROWS), (i_n, i_t, i_hv, v_base // MTP_ILP_ROWS) + ) + cute.autovec_copy(vt_slice, r_v_bf16_vec) + vn0 = (cutlass.Float32(r_v_bf16_vec[0]) - s0) * r_beta + vn1 = (cutlass.Float32(r_v_bf16_vec[1]) - s1) * r_beta + vn2 = (cutlass.Float32(r_v_bf16_vec[2]) - s2) * r_beta + vn3 = (cutlass.Float32(r_v_bf16_vec[3]) - s3) * r_beta + vn4 = (cutlass.Float32(r_v_bf16_vec[4]) - s4) * r_beta + vn5 = (cutlass.Float32(r_v_bf16_vec[5]) - s5) * r_beta + vn6 = (cutlass.Float32(r_v_bf16_vec[6]) - s6) * r_beta + vn7 = (cutlass.Float32(r_v_bf16_vec[7]) - s7) * r_beta + + # Rank-1 update + output dot product h@q with conditional dispatch + o0 = 0.0 + o1 = 0.0 + o2 = 0.0 + o3 = 0.0 + o4 = 0.0 + o5 = 0.0 + o6 = 0.0 + o7 = 0.0 + o0b = 0.0 + o1b = 0.0 + o2b = 0.0 + o3b = 0.0 + o4b = 0.0 + o5b = 0.0 + o6b = 0.0 + o7b = 0.0 + for i in cutlass.range_constexpr(0, vec_size, 2): + # Rank-1 update: h += k * vn (paired FMA) + if cutlass.const_expr(use_packed_fma): + r_h[0, i], r_h[0, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn0, vn0), + src_c=(r_h[0, i], r_h[0, i + 1]), + ) + r_h[1, i], r_h[1, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn1, vn1), + src_c=(r_h[1, i], r_h[1, i + 1]), + ) + r_h[2, i], r_h[2, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn2, vn2), + src_c=(r_h[2, i], r_h[2, i + 1]), + ) + r_h[3, i], r_h[3, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn3, vn3), + src_c=(r_h[3, i], r_h[3, i + 1]), + ) + r_h[4, i], r_h[4, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn4, vn4), + src_c=(r_h[4, i], r_h[4, i + 1]), + ) + r_h[5, i], r_h[5, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn5, vn5), + src_c=(r_h[5, i], r_h[5, i + 1]), + ) + r_h[6, i], r_h[6, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn6, vn6), + src_c=(r_h[6, i], r_h[6, i + 1]), + ) + r_h[7, i], r_h[7, i + 1] = cute.arch.fma_packed_f32x2( + src_a=(r_k[i], r_k[i + 1]), + src_b=(vn7, vn7), + src_c=(r_h[7, i], r_h[7, i + 1]), + ) + else: + r_h[0, i], r_h[0, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn0, vn0, r_h[0, i], r_h[0, i + 1] + ) + r_h[1, i], r_h[1, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn1, vn1, r_h[1, i], r_h[1, i + 1] + ) + r_h[2, i], r_h[2, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn2, vn2, r_h[2, i], r_h[2, i + 1] + ) + r_h[3, i], r_h[3, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn3, vn3, r_h[3, i], r_h[3, i + 1] + ) + r_h[4, i], r_h[4, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn4, vn4, r_h[4, i], r_h[4, i + 1] + ) + r_h[5, i], r_h[5, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn5, vn5, r_h[5, i], r_h[5, i + 1] + ) + r_h[6, i], r_h[6, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn6, vn6, r_h[6, i], r_h[6, i + 1] + ) + r_h[7, i], r_h[7, i + 1] = fma_pair( + r_k[i], r_k[i + 1], vn7, vn7, r_h[7, i], r_h[7, i + 1] + ) + # Output dot product: o += h * q (paired FMA) + if cutlass.const_expr(use_packed_fma): + o0, o0b = cute.arch.fma_packed_f32x2( + src_a=(r_h[0, i], r_h[0, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o0, o0b), + ) + o1, o1b = cute.arch.fma_packed_f32x2( + src_a=(r_h[1, i], r_h[1, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o1, o1b), + ) + o2, o2b = cute.arch.fma_packed_f32x2( + src_a=(r_h[2, i], r_h[2, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o2, o2b), + ) + o3, o3b = cute.arch.fma_packed_f32x2( + src_a=(r_h[3, i], r_h[3, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o3, o3b), + ) + o4, o4b = cute.arch.fma_packed_f32x2( + src_a=(r_h[4, i], r_h[4, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o4, o4b), + ) + o5, o5b = cute.arch.fma_packed_f32x2( + src_a=(r_h[5, i], r_h[5, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o5, o5b), + ) + o6, o6b = cute.arch.fma_packed_f32x2( + src_a=(r_h[6, i], r_h[6, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o6, o6b), + ) + o7, o7b = cute.arch.fma_packed_f32x2( + src_a=(r_h[7, i], r_h[7, i + 1]), + src_b=(r_q[i], r_q[i + 1]), + src_c=(o7, o7b), + ) + else: + o0, o0b = fma_pair( + r_h[0, i], r_h[0, i + 1], r_q[i], r_q[i + 1], o0, o0b + ) + o1, o1b = fma_pair( + r_h[1, i], r_h[1, i + 1], r_q[i], r_q[i + 1], o1, o1b + ) + o2, o2b = fma_pair( + r_h[2, i], r_h[2, i + 1], r_q[i], r_q[i + 1], o2, o2b + ) + o3, o3b = fma_pair( + r_h[3, i], r_h[3, i + 1], r_q[i], r_q[i + 1], o3, o3b + ) + o4, o4b = fma_pair( + r_h[4, i], r_h[4, i + 1], r_q[i], r_q[i + 1], o4, o4b + ) + o5, o5b = fma_pair( + r_h[5, i], r_h[5, i + 1], r_q[i], r_q[i + 1], o5, o5b + ) + o6, o6b = fma_pair( + r_h[6, i], r_h[6, i + 1], r_q[i], r_q[i + 1], o6, o6b + ) + o7, o7b = fma_pair( + r_h[7, i], r_h[7, i + 1], r_q[i], r_q[i + 1], o7, o7b + ) + # Combine paired accumulators + o0 = o0 + o0b + o1 = o1 + o1b + o2 = o2 + o2b + o3 = o3 + o3b + o4 = o4 + o4b + o5 = o5 + o5b + o6 = o6 + o6b + o7 = o7 + o7b + + # Start FP32→BF16 conversion for intermediate state BEFORE shuffles + # (overlaps conversion with shuffle pipeline) + if cutlass.const_expr(cache_intermediate_states): + for i in cutlass.range_constexpr(vec_size): + r_hb0[i] = cutlass.BFloat16(r_h[0, i]) + r_hb1[i] = cutlass.BFloat16(r_h[1, i]) + r_hb2[i] = cutlass.BFloat16(r_h[2, i]) + r_hb3[i] = cutlass.BFloat16(r_h[3, i]) + r_hb4[i] = cutlass.BFloat16(r_h[4, i]) + r_hb5[i] = cutlass.BFloat16(r_h[5, i]) + r_hb6[i] = cutlass.BFloat16(r_h[6, i]) + r_hb7[i] = cutlass.BFloat16(r_h[7, i]) + + # Write intermediate state BEFORE output shuffles (issue stores early to overlap with shuffles) + if cutlass.const_expr(cache_intermediate_states): + flat_idx = cache_idx * T * HV + i_t * HV + i_hv + it0 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v0, lane_in_group), + ) + it1 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v1, lane_in_group), + ) + it2 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v2, lane_in_group), + ) + it3 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v3, lane_in_group), + ) + it4 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v4, lane_in_group), + ) + it5 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v5, lane_in_group), + ) + it6 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v6, lane_in_group), + ) + it7 = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v7, lane_in_group), + ) + cute.autovec_copy(r_hb0, it0) + cute.autovec_copy(r_hb1, it1) + cute.autovec_copy(r_hb2, it2) + cute.autovec_copy(r_hb3, it3) + cute.autovec_copy(r_hb4, it4) + cute.autovec_copy(r_hb5, it5) + cute.autovec_copy(r_hb6, it6) + cute.autovec_copy(r_hb7, it7) + + # Interleaved butterfly reduction for 8 o-values + for offset in [16, 8, 4, 2, 1]: + o0 += cute.arch.shuffle_sync_bfly( + o0, offset=offset, mask=-1, mask_and_clamp=31 + ) + o1 += cute.arch.shuffle_sync_bfly( + o1, offset=offset, mask=-1, mask_and_clamp=31 + ) + o2 += cute.arch.shuffle_sync_bfly( + o2, offset=offset, mask=-1, mask_and_clamp=31 + ) + o3 += cute.arch.shuffle_sync_bfly( + o3, offset=offset, mask=-1, mask_and_clamp=31 + ) + o4 += cute.arch.shuffle_sync_bfly( + o4, offset=offset, mask=-1, mask_and_clamp=31 + ) + o5 += cute.arch.shuffle_sync_bfly( + o5, offset=offset, mask=-1, mask_and_clamp=31 + ) + o6 += cute.arch.shuffle_sync_bfly( + o6, offset=offset, mask=-1, mask_and_clamp=31 + ) + o7 += cute.arch.shuffle_sync_bfly( + o7, offset=offset, mask=-1, mask_and_clamp=31 + ) + + # Write output: vectorized BF16 store + if lane_in_group == 0: + r_o_bf16_vec[0] = cutlass.BFloat16(o0) + r_o_bf16_vec[1] = cutlass.BFloat16(o1) + r_o_bf16_vec[2] = cutlass.BFloat16(o2) + r_o_bf16_vec[3] = cutlass.BFloat16(o3) + r_o_bf16_vec[4] = cutlass.BFloat16(o4) + r_o_bf16_vec[5] = cutlass.BFloat16(o5) + r_o_bf16_vec[6] = cutlass.BFloat16(o6) + r_o_bf16_vec[7] = cutlass.BFloat16(o7) + ot_slice = cute.local_tile( + o, + (1, 1, 1, MTP_ILP_ROWS), + (i_n, i_t, i_hv, v_base // MTP_ILP_ROWS), + ) + cute.autovec_copy(r_o_bf16_vec, ot_slice) + + # Write final state back as BF16 (if not disabled) + if cutlass.const_expr(not disable_state_update): + for i in cutlass.range_constexpr(vec_size): + r_hb0[i] = cutlass.BFloat16(r_h[0, i]) + r_hb1[i] = cutlass.BFloat16(r_h[1, i]) + r_hb2[i] = cutlass.BFloat16(r_h[2, i]) + r_hb3[i] = cutlass.BFloat16(r_h[3, i]) + r_hb4[i] = cutlass.BFloat16(r_h[4, i]) + r_hb5[i] = cutlass.BFloat16(r_h[5, i]) + r_hb6[i] = cutlass.BFloat16(r_h[6, i]) + r_hb7[i] = cutlass.BFloat16(r_h[7, i]) + cute.autovec_copy(r_hb0, ht0) + cute.autovec_copy(r_hb1, ht1) + cute.autovec_copy(r_hb2, ht2) + cute.autovec_copy(r_hb3, ht3) + cute.autovec_copy(r_hb4, ht4) + cute.autovec_copy(r_hb5, ht5) + cute.autovec_copy(r_hb6, ht6) + cute.autovec_copy(r_hb7, ht7) - cute.arch.sync_threads() - - # V loading - tokens 0, 1 always - v_head0 = gV[(batch_idx, 0, value_head_idx, None)] - v_head1 = gV[(batch_idx, 1, value_head_idx, None)] - load_v_to_smem(v_head0, v_sh0, tidx) - load_v_to_smem(v_head1, v_sh1, tidx) - - # Token 2 V loading - only for NUM_TOKENS >= 3 - if NUM_TOKENS >= 3: - v_head2 = gV[(batch_idx, 2, value_head_idx, None)] - load_v_to_smem(v_head2, v_sh2, tidx) - - # Token 3 V loading - only for NUM_TOKENS = 4 - if NUM_TOKENS == 4: - v_head3 = gV[(batch_idx, 3, value_head_idx, None)] - load_v_to_smem(v_head3, v_sh3, tidx) - - # Output pointers - tokens 0, 1 always - h_out = gH[(pool_batch_idx, value_head_idx, None, None)] - o_head0 = gO[(batch_idx, 0, value_head_idx, None)] - o_head1 = gO[(batch_idx, 1, value_head_idx, None)] - - # Token 2 output pointer - o_head2 = o_head1 # Default for T=2 - if NUM_TOKENS >= 3: - o_head2 = gO[(batch_idx, 2, value_head_idx, None)] - - # Token 3 output pointer - o_head3 = o_head2 # Default for T=2,3 - if NUM_TOKENS == 4: - o_head3 = gO[(batch_idx, 3, value_head_idx, None)] - - # Process V-CHUNK 0 - nvvm.cp_async_wait_group(3) - cute.arch.sync_threads() - process_vchunk_unified_234( - h_sh_chunk0, - h_sh_chunk0, - h_out, - h_chunk, - kq_chunk, - k_sh0, - k_sh1, - k_sh2, - k_sh3, - q_sh0, - q_sh1, - q_sh2, - q_sh3, - v_sh0, - v_sh1, - v_sh2, - v_sh3, - reduce_sh, - o_head0, - o_head1, - o_head2, - o_head3, - g_exp0, - g_exp1, - g_exp2, - g_exp3, - beta0, - beta1, - beta2, - beta3, - 0, - 0, - cutlass.Int32(0), - tidx, - warp_idx, - lane_idx, - k_base, - NUM_TOKENS, - ) - # Process V-CHUNK 1 - nvvm.cp_async_wait_group(2) - cute.arch.sync_threads() - process_vchunk_unified_234( - h_sh_chunk1, - h_sh_chunk0, - h_out, - h_chunk, - kq_chunk, - k_sh0, - k_sh1, - k_sh2, - k_sh3, - q_sh0, - q_sh1, - q_sh2, - q_sh3, - v_sh0, - v_sh1, - v_sh2, - v_sh3, - reduce_sh, - o_head0, - o_head1, - o_head2, - o_head3, - g_exp0, - g_exp1, - g_exp2, - g_exp3, - beta0, - beta1, - beta2, - beta3, - 32, - 0, - cutlass.Int32(1), - tidx, - warp_idx, - lane_idx, - k_base, - NUM_TOKENS, - ) +# ============================================================================== +# LAUNCH WRAPPER (MTP version) +# ============================================================================== - # Process V-CHUNK 2 - nvvm.cp_async_wait_group(1) - cute.arch.sync_threads() - process_vchunk_unified_234( - h_sh_chunk2, - h_sh_chunk1, - h_out, - h_chunk, - kq_chunk, - k_sh0, - k_sh1, - k_sh2, - k_sh3, - q_sh0, - q_sh1, - q_sh2, - q_sh3, - v_sh0, - v_sh1, - v_sh2, - v_sh3, - reduce_sh, - o_head0, - o_head1, - o_head2, - o_head3, - g_exp0, - g_exp1, - g_exp2, - g_exp3, - beta0, - beta1, - beta2, - beta3, - 64, - 32, - cutlass.Int32(1), - tidx, - warp_idx, - lane_idx, - k_base, - NUM_TOKENS, - ) - # Process V-CHUNK 3 - nvvm.cp_async_wait_group(0) - cute.arch.sync_threads() - process_vchunk_unified_234( - h_sh_chunk3, - h_sh_chunk2, - h_out, - h_chunk, - kq_chunk, - k_sh0, - k_sh1, - k_sh2, - k_sh3, - q_sh0, - q_sh1, - q_sh2, - q_sh3, - v_sh0, - v_sh1, - v_sh2, - v_sh3, - reduce_sh, - o_head0, - o_head1, - o_head2, - o_head3, - g_exp0, - g_exp1, - g_exp2, - g_exp3, - beta0, - beta1, - beta2, - beta3, - 96, - 64, - cutlass.Int32(1), - tidx, - warp_idx, - lane_idx, - k_base, - NUM_TOKENS, +@cute.jit +def run_gdn_decode_bf16state_mtp( + h0_source: cute.Tensor, # [pool_size * HV, V, K] BF16 + intermediate_states: cute.Tensor, # [pool_size * T * HV, V, K] BF16 (or dummy) + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + h0_indices: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + B: cutlass.Constexpr[int], + T: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + tile_v_param: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + disable_state_update: cutlass.Constexpr[bool], + cache_intermediate_states: cutlass.Constexpr[bool], + use_packed_fma: cutlass.Constexpr[bool], + stream: cuda.CUstream, +): + """Launch the MTP kernel for BF16 state.""" + tile_v = tile_v_param + vec_size = MTP_VEC_SIZE + _, v_dim, _k_dim = ( + h0_source.layout.shape[0], + h0_source.layout.shape[1], + h0_source.layout.shape[2], ) - # Final H store - cute.arch.sync_threads() - store_h_smem_to_gmem(h_sh_chunk3, h_out, tidx, 96) + num_v_tiles = cute.ceil_div(v_dim, tile_v) + grid_size = B * HV * num_v_tiles + + # SMEM: for T>1 include shared sQ/sK (1 copy) + sGB; T=1 needs minimal + smem_bytes = 128 # alignment padding + if T > 1: + smem_bytes = ( + 4 * T * (K + 8) # sQ: T × (K+8) × 4 bytes (shared, one copy) + + 4 * T * (K + 8) # sK: same + + 4 * T * 2 # sGB: T × 2 × 4 bytes (shared) + + 128 # alignment padding + ) + + gdn_decode_bf16state_mtp_kernel( + h0_source, + intermediate_states, + vec_size, + num_v_tiles, + tile_v, + A_log, + a, + dt_bias, + q, + k, + v, + b, + o, + h0_indices, + softplus_beta, + softplus_threshold, + scale, + HV, + B, + T, + H, + K, + V, + use_qk_l2norm, + disable_state_update, + cache_intermediate_states, + use_packed_fma, + ).launch( + grid=(grid_size, 1, 1), + block=[MTP_NUM_THREADS, 1, 1], + smem=smem_bytes, + stream=stream, + ) # ============================================================================== -# LAUNCH WRAPPERS +# LAUNCH WRAPPER (ILP version) # ============================================================================== @cute.jit -def gated_delta_rule_launch_seqlen1( - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - ma: cute.Tensor, - mb: cute.Tensor, - mA_log: cute.Tensor, - mdt_bias: cute.Tensor, - mH: cute.Tensor, - mH_slot_indices: cute.Tensor, - mO: cute.Tensor, - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, +def run_gdn_decode_bf16state_ilp( + h0_source: cute.Tensor, # [B*HV, V, K] BF16 + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + B: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + use_packed_fma: cutlass.Constexpr[bool], + tile_v_param: cutlass.Constexpr[int], stream: cuda.CUstream, ): - batch_size = mQ.shape[0] - HV = mV.shape[2] - - gated_delta_rule_decode_kernel_seqlen1( - mQ, - mK, - mV, - ma, - mb, - mA_log, - mdt_bias, - mH, - mH_slot_indices, - mO, - scale, + """Launch the ILP-optimized kernel for T=1 with large batch sizes.""" + tile_v = tile_v_param + vec_size = VEC_SIZE_ILP + _, v_dim, _k_dim = ( + h0_source.layout.shape[0], + h0_source.layout.shape[1], + h0_source.layout.shape[2], + ) + + num_v_tiles = cute.ceil_div(v_dim, tile_v) + grid_size = B * HV * num_v_tiles + + # SMEM: minimal (direct GMEM access) + smem_bytes = 128 + + gdn_decode_bf16state_ilp_kernel( + h0_source, + vec_size, + num_v_tiles, + tile_v, + A_log, + a, + dt_bias, + q, + k, + v, + b, + o, softplus_beta, softplus_threshold, - eps, + scale, + HV, + H, + K, + V, + use_qk_l2norm, + use_packed_fma, ).launch( - grid=[batch_size * HV, 1, 1], - block=[128, 1, 1], + grid=(grid_size, 1, 1), + block=[NUM_THREADS_ILP, 1, 1], + smem=smem_bytes, stream=stream, ) # ============================================================================== -# LOW-BS SEQLEN=1 KERNEL - 1 V-CHUNK PER CTA (T=1, BS<=4) +# LAUNCH WRAPPER (original cp.async pipeline version) # ============================================================================== -@cute.kernel -def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( - gQ: cute.Tensor, - gK: cute.Tensor, - gV: cute.Tensor, - ga: cute.Tensor, - gb: cute.Tensor, - gA_log: cute.Tensor, - gdt_bias: cute.Tensor, - gH: cute.Tensor, - gH_slot_indices: cute.Tensor, - gO: cute.Tensor, - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, +@cute.jit +def run_gdn_decode_bf16state_cooprow( + h0_source: cute.Tensor, # [B*HV, V, K] BF16 + A_log: cute.Tensor, + a: cute.Tensor, + dt_bias: cute.Tensor, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + b: cute.Tensor, + o: cute.Tensor, + softplus_beta: cutlass.Constexpr[float], + softplus_threshold: cutlass.Constexpr[float], + scale: cutlass.Constexpr[float], + HV: cutlass.Constexpr[int], + H: cutlass.Constexpr[int], + K: cutlass.Constexpr[int], + V: cutlass.Constexpr[int], + use_qk_l2norm: cutlass.Constexpr[bool], + stream: cuda.CUstream, ): - """ - Seqlen=1 kernel with 1 V-chunk (32 V rows) per CTA. - For T=1, batch_size <= 4: more CTAs per batch*head for better SM utilization. - Grid: batch_idx * HV * 4 + value_head_idx * 4 + v_chunk_idx (0..3). - """ - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - HV = cutlass.Int32(gV.shape[2]) - H = cutlass.Int32(gQ.shape[2]) - - batch_idx = bidx // (HV * 4) - remainder = bidx % (HV * 4) - value_head_idx = remainder // 4 - v_chunk_idx = remainder % 4 - - query_head_idx = value_head_idx // (HV // H) - v_row_base = v_chunk_idx * 32 - pool_batch_idx = gH_slot_indices[batch_idx] - if pool_batch_idx < 0: - pool_batch_idx = cutlass.Int32(0) - - smem = utils.SmemAllocator() - - alpha = ga[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) - beta_raw = gb[(batch_idx, 0, value_head_idx)].to(cutlass.Float32) - A_log_val = gA_log[value_head_idx] - dt_bias_val = gdt_bias[value_head_idx] - g_exp, beta = compute_single_gate( - alpha, beta_raw, dt_bias_val, A_log_val, softplus_beta, softplus_threshold + """Launch the diff-approach kernel for T=1.""" + batch_size, v_dim, _k_dim = ( + h0_source.layout.shape[0], + h0_source.layout.shape[1], + h0_source.layout.shape[2], ) - h_sh_chunk = smem.allocate_tensor( - cutlass.BFloat16, cute.make_layout((32, 128), stride=(H_SMEM_STRIDE, 1)) + # BF16 async copy: 128-bit = 8 BF16 elements per copy + copy_atom = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + cutlass.BFloat16, + num_bits_per_copy=128, + ) + + # Thread layout: 8 rows × 16 threads/row = 128 threads + # 16 threads × 8 BF16 elements = 128 K-elements per row + # 8 rows = TILE_V rows per copy (covers full tile in one shot) + thread_layout = cute.make_layout( + (8, 16), + stride=(16, 1), ) + val_layout = cute.make_layout((1, 8)) # 8 BF16 elements per copy = 128 bits + + tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout) - q_sh = smem.allocate_tensor(cutlass.Float32, 128) - k_sh = smem.allocate_tensor(cutlass.Float32, 128) + num_v_tiles = cute.ceil_div(v_dim, TILE_V) - pred_sh = smem.allocate_tensor( - cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + vec_size = TILE_K // 32 # = 4 + + # SMEM layout: (TILE_V, TILE_K, NUM_STAGES) in BF16 + smem_layout_staged = cute.make_layout( + (TILE_V, TILE_K, NUM_STAGES), stride=(TILE_K, 1, TILE_V * TILE_K) ) - out_sh = smem.allocate_tensor( - cutlass.Float32, cute.make_layout((32, 4), stride=(1, 32)) + + # SMEM: sData (BF16) + sV (FP32) + sOutput (BF16) + smem_bytes = ( + 2 * TILE_V * TILE_K * NUM_STAGES # sData: BF16 + + 4 * v_dim # sV: FP32 + + 2 * v_dim # sOutput: BF16 + + 128 # alignment padding ) - h_global = gH[(pool_batch_idx, value_head_idx, None, None)] + gdn_decode_bf16state_cooprow_kernel( + tiled_copy_load, + h0_source, + smem_layout_staged, + vec_size, + num_v_tiles, + A_log, + a, + dt_bias, + q, + k, + v, + b, + o, + softplus_beta, + softplus_threshold, + scale, + HV, + H, + K, + V, + use_qk_l2norm, + ).launch( + grid=(batch_size * NUM_BLOCKS_PER_STATE, 1, 1), + block=[NUM_THREADS, 1, 1], + smem=smem_bytes, + stream=stream, + ) - load_h_chunk_async(h_sh_chunk, h_global, tidx, v_row_base) - nvvm.cp_async_commit_group() - q_head = gQ[(batch_idx, 0, query_head_idx, None)] - k_head = gK[(batch_idx, 0, query_head_idx, None)] +# ============================================================================== +# PUBLIC API +# ============================================================================== +_compiled_kernels: dict = {} +_compiled_kernels_ilp: dict = {} - warp_idx = tidx // 32 - lane_idx = tidx % 32 +# Batch size threshold for ILP kernel dispatch +ILP_BATCH_THRESHOLD = 16 # Use ILP kernel for B >= 16 - if warp_idx == 0: - normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale, eps) +# Number of SMs on target GPU (detected dynamically) +NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count - cute.arch.sync_threads() - v_head = gV[(batch_idx, 0, value_head_idx, None)] - v_sh = smem.allocate_tensor(cutlass.Float32, 32) - if tidx < 32: - v_sh[tidx] = v_head[v_row_base + tidx].to(cutlass.Float32) +def _select_tile_v_for_batch(B: int, HV: int, V: int) -> int: + """Select optimal tile_v for the ILP kernel based on batch size. - h_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) - k_chunk = cute.make_rmem_tensor((32,), cutlass.Float32) - qk_temp = cute.make_rmem_tensor((32,), cutlass.Float32) + Goal: maximize GPU occupancy by ensuring enough blocks to fill all SMs. + Each block handles tile_v V-rows, grid = B * HV * (V / tile_v). + We want at least ~4 waves (4 * NUM_SMS blocks) for good occupancy, + since register pressure limits per-SM occupancy. - k_base = warp_idx * 32 + tile_v must be a multiple of 32 (4 groups * ILP_ROWS=8) and divide V=128. + Valid values: 32, 64, 128. + """ + for tv in [128, 64, 32]: + num_v_tiles = V // tv + grid_size = B * HV * num_v_tiles + # Want at least 4 waves for good occupancy (register pressure limits to ~3 blocks/SM) + if grid_size >= 4 * NUM_SMS: + return tv + return 32 # Minimum tile_v for maximum parallelism - for i in cutlass.range_constexpr(32): - k_chunk[i] = k_sh[k_base + i] - h_out = gH[(pool_batch_idx, value_head_idx, None, None)] - o_head = gO[(batch_idx, 0, value_head_idx, None)] +def gated_delta_rule( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + q: Optional[torch.Tensor] = None, + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + b: Optional[torch.Tensor] = None, + initial_state_source: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = True, + scale: Optional[float] = None, +) -> torch.Tensor: + """ + GDN decode T=1 with BF16 state. - nvvm.cp_async_wait_group(0) - cute.arch.sync_threads() + Args: + A_log: [HV] float32 + a: [B, 1, HV] bf16 + dt_bias: [HV] float32 + q: [B, 1, H, K] bf16 + k: [B, 1, H, K] bf16 + v: [B, 1, HV, V] bf16 + b: [B, 1, HV] bf16 + initial_state_source: [B, HV, V, K] bf16 (modified in-place) + scale: Optional, default 1/sqrt(K) - pred = cutlass.Float32(0.0) - pred2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair_mul( - h_sh_chunk[lane_idx, k_base + i].to(cutlass.Float32), - h_sh_chunk[lane_idx, k_base + i + 1].to(cutlass.Float32), - g_exp, - g_exp, - ) - for i in cutlass.range_constexpr(0, 32, 2): - pred, pred2 = fma_pair( - h_chunk[i], h_chunk[i + 1], k_chunk[i], k_chunk[i + 1], pred, pred2 - ) - pred = pred + pred2 - - pred_sh[lane_idx, warp_idx] = pred - cute.arch.sync_threads() - pred_final = ( - pred_sh[lane_idx, 0] - + pred_sh[lane_idx, 1] - + pred_sh[lane_idx, 2] - + pred_sh[lane_idx, 3] - ) + Returns: + output: [B, 1, HV, V] bf16 + """ + global _compiled_kernels_ilp - v_val = (v_sh[lane_idx] - pred_final) * beta + assert q is not None and k is not None and v is not None + assert b is not None and initial_state_source is not None - for i in cutlass.range_constexpr(0, 32, 2): - h_chunk[i], h_chunk[i + 1] = fma_pair( - k_chunk[i], k_chunk[i + 1], v_val, v_val, h_chunk[i], h_chunk[i + 1] - ) + B, T, H, K = q.shape + assert T == 1, f"This kernel only supports T=1, got T={T}" + HV = v.shape[2] + V = v.shape[3] + assert K == 128 and V == 128, f"K and V must be 128, got K={K}, V={V}" + assert initial_state_source.dtype == torch.bfloat16 - for i in cutlass.range_constexpr(32): - qk_temp[i] = q_sh[k_base + i] + if scale is None: + scale = 1.0 / math.sqrt(K) - out = cutlass.Float32(0.0) - out2 = cutlass.Float32(0.0) - for i in cutlass.range_constexpr(0, 32, 2): - out, out2 = fma_pair( - h_chunk[i], h_chunk[i + 1], qk_temp[i], qk_temp[i + 1], out, out2 + # Small batch: route through MTP kernel (T=1 path) with identity indices. + # The cooprow kernel has known correctness issues at small batch sizes (e.g. B=2). + # The MTP kernel's T=1 path uses the same ILP-style computation and is well-tested. + if B < ILP_BATCH_THRESHOLD: + return gated_delta_rule_mtp( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + q=q, + k=k, + v=v, + b=b, + initial_state_source=initial_state_source, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + scale=scale, ) - out = out + out2 - - out_sh[lane_idx, warp_idx] = out - cute.arch.sync_threads() - out_final = ( - out_sh[lane_idx, 0] - + out_sh[lane_idx, 1] - + out_sh[lane_idx, 2] - + out_sh[lane_idx, 3] - ) - - write_h_chunk_to_smem(h_chunk, h_sh_chunk, lane_idx, k_base) - if warp_idx == 0: - o_head[v_row_base + lane_idx] = out_final.to(cutlass.BFloat16) - cute.arch.sync_threads() - store_h_smem_to_gmem(h_sh_chunk, h_out, tidx, v_row_base) + output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) + # Reshape state to [B*HV, V, K] + h0_source = initial_state_source.reshape(B * HV, V, K) -@cute.jit -def gated_delta_rule_launch_seqlen1_lowBS_1chunk( - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - ma: cute.Tensor, - mb: cute.Tensor, - mA_log: cute.Tensor, - mdt_bias: cute.Tensor, - mH: cute.Tensor, - mH_slot_indices: cute.Tensor, - mO: cute.Tensor, - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, - stream: cuda.CUstream, -): - """Launch LowBS-1 kernel: 4 CTAs per (batch, value_head).""" - batch_size = mQ.shape[0] - HV = mV.shape[2] - - gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( - mQ, - mK, - mV, - ma, - mb, - mA_log, - mdt_bias, - mH, - mH_slot_indices, - mO, - scale, - softplus_beta, - softplus_threshold, - eps, - ).launch( - grid=[batch_size * HV * 4, 1, 1], - block=[128, 1, 1], - stream=stream, - ) + q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True) + k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True) + v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True) + a_ = from_dlpack(a, assumed_align=32, enable_tvm_ffi=True) + b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True) + A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True) + dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) + h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True) + o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) -@cute.jit -def gated_delta_rule_launch_seqlen2( - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - ma: cute.Tensor, - mb: cute.Tensor, - mA_log: cute.Tensor, - mdt_bias: cute.Tensor, - mH: cute.Tensor, - mH_slot_indices: cute.Tensor, - mO: cute.Tensor, - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, - stream: cuda.CUstream, -): - batch_size = mQ.shape[0] - HV = mV.shape[2] - - gated_delta_rule_decode_kernel_seqlen234_unified( - mQ, - mK, - mV, - ma, - mb, - mA_log, - mdt_bias, - mH, - mH_slot_indices, - mO, + major, _ = torch.cuda.get_device_capability(q.device) + use_packed_fma = major >= 10 + + # B >= ILP_BATCH_THRESHOLD (small B handled by MTP path above) + tile_v = _select_tile_v_for_batch(B, HV, V) + cache_key = ( + "ilp", + B, + H, + HV, + K, + V, + tile_v, scale, softplus_beta, softplus_threshold, - eps, - 2, # NUM_TOKENS=2 - ).launch( - grid=[batch_size * HV, 1, 1], - block=[128, 1, 1], - stream=stream, + use_packed_fma, ) + if cache_key not in _compiled_kernels_ilp: + # Use maxrregcount=64 for smaller tile_v to improve occupancy + # when grid size is small (fewer waves) + if tile_v < 128: + compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3 --ptxas-options=-maxrregcount=64" + else: + compile_opts = "--enable-tvm-ffi --generate-line-info --opt-level 3" + _compiled_kernels_ilp[cache_key] = cute.compile( + run_gdn_decode_bf16state_ilp, + h_, + A_log_, + a_, + dt_bias_, + q_, + k_, + v_, + b_, + o_, + softplus_beta, + softplus_threshold, + scale, + HV, + B, + H, + K, + V, + use_qk_l2norm_in_kernel, + use_packed_fma, + tile_v, + stream, + options=compile_opts, + ) - -@cute.jit -def gated_delta_rule_launch_seqlen3( - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - ma: cute.Tensor, - mb: cute.Tensor, - mA_log: cute.Tensor, - mdt_bias: cute.Tensor, - mH: cute.Tensor, - mH_slot_indices: cute.Tensor, - mO: cute.Tensor, - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, - stream: cuda.CUstream, -): - batch_size = mQ.shape[0] - HV = mV.shape[2] - - gated_delta_rule_decode_kernel_seqlen234_unified( - mQ, - mK, - mV, - ma, - mb, - mA_log, - mdt_bias, - mH, - mH_slot_indices, - mO, - scale, - softplus_beta, - softplus_threshold, - eps, - 3, # NUM_TOKENS=3 - ).launch( - grid=[batch_size * HV, 1, 1], - block=[128, 1, 1], - stream=stream, + _compiled_kernels_ilp[cache_key]( + h_, + A_log_, + a_, + dt_bias_, + q_, + k_, + v_, + b_, + o_, + stream, ) - -@cute.jit -def gated_delta_rule_launch_seqlen4( - mQ: cute.Tensor, - mK: cute.Tensor, - mV: cute.Tensor, - ma: cute.Tensor, - mb: cute.Tensor, - mA_log: cute.Tensor, - mdt_bias: cute.Tensor, - mH: cute.Tensor, - mH_slot_indices: cute.Tensor, - mO: cute.Tensor, - scale: cutlass.Float32, - softplus_beta: cutlass.Float32, - softplus_threshold: cutlass.Float32, - eps: cutlass.Float32, - stream: cuda.CUstream, -): - batch_size = mQ.shape[0] - HV = mV.shape[2] - - gated_delta_rule_decode_kernel_seqlen234_unified( - mQ, - mK, - mV, - ma, - mb, - mA_log, - mdt_bias, - mH, - mH_slot_indices, - mO, - scale, - softplus_beta, - softplus_threshold, - eps, - 4, # NUM_TOKENS=4 - ).launch( - grid=[batch_size * HV, 1, 1], - block=[128, 1, 1], - stream=stream, - ) + return output # ============================================================================== -# KERNEL CLASS +# MTP PUBLIC API # ============================================================================== +_compiled_kernels_mtp: dict = {} -class GatedDeltaRuleKernel: - """ - Gated Delta Rule Kernel for linear attention decode. - - This kernel implements the Gated Delta Rule mechanism supporting sequence - lengths T=1, T=2, T=3, T=4 with optimized CUDA implementations. +def _select_tile_v_for_mtp(B: int, HV: int, V: int, T: int = 1) -> int: + """Select optimal tile_v for the MTP BF16 kernel based on batch size and T. - Key features: - - T=1: Persistent K in registers with aggressive pipelining - - T=2/3/4: Unified kernel with compile-time Constexpr specialization - - L2-normalized Q/K with configurable scale - - Gated exponential decay via softplus - - Bank-conflict-free cross-warp reductions - - Async H memory loading + tile_v must be a multiple of MTP_ILP_ROWS * 4 (= 32) and divide V=128. + Valid values: 32, 64, 128. + With ILP=8, minimum tile_v = 4 * 8 = 32 (4 groups * 8 ILP_ROWS). - Args: - seq_len: Sequence length (1, 2, 3, or 4) + For large batch sizes, use larger tile_v to reduce block count and overhead. """ - - def __init__(self, seq_len: int): - assert seq_len in [1, 2, 3, 4], f"Supported seq_len: 1,2,3,4, got {seq_len}" - self.seq_len = seq_len - self._compiled_kernel = None - - def _get_launch_fn(self): - if self.seq_len == 1: - return gated_delta_rule_launch_seqlen1 - elif self.seq_len == 2: - return gated_delta_rule_launch_seqlen2 - elif self.seq_len == 3: - return gated_delta_rule_launch_seqlen3 - else: - return gated_delta_rule_launch_seqlen4 + for tv in [128, 64, 32]: + num_v_tiles = V // tv + grid_size = B * HV * num_v_tiles + # Want at least 4 waves for good occupancy + if grid_size >= 4 * NUM_SMS: + return tv + return 32 # Minimum tile_v for maximum parallelism -# ============================================================================== -# PUBLIC API -# ============================================================================== - -_compiled_kernels = {} # Cache: (seqlen, batch_size) -> compiled kernel - - -def gated_delta_rule( +def gated_delta_rule_mtp( A_log: torch.Tensor, a: torch.Tensor, dt_bias: torch.Tensor, @@ -1924,84 +2467,82 @@ def gated_delta_rule( b: Optional[torch.Tensor] = None, initial_state_source: Optional[torch.Tensor] = None, initial_state_indices: Optional[torch.Tensor] = None, + intermediate_states_buffer: Optional[torch.Tensor] = None, + disable_state_update: bool = False, use_qk_l2norm_in_kernel: bool = True, scale: Optional[float] = None, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Gated Delta Rule linear attention kernel. - - Implements the Gated Delta Rule mechanism for decode-phase inference, - supporting sequence lengths T=1, T=2, T=3, T=4. + GDN MTP (Multiple Token Processing) with BF16 state. + Processes T tokens sequentially, keeping h in FP32 registers. + H state loaded/stored as BF16. Args: - A_log: Log decay parameter [HV] - a: Alpha gate input [B, T, HV] - dt_bias: Delta-t bias [HV] - softplus_beta: Softplus beta parameter (default: 1.0) - softplus_threshold: Softplus threshold (default: 20.0) - q: Query tensor [B, T, H, K] - k: Key tensor [B, T, H, K] - v: Value tensor [B, T, HV, V] - b: Beta gate input [B, T, HV] - initial_state_source: H state [pool_size, HV, V, K] (K-fast layout), modified in-place. - For the direct path (no pool), pass [B, HV, V, K] and omit initial_state_indices. - initial_state_indices: Per-batch indices [B] (int32) mapping each batch entry to its - slot in initial_state_source. When None, uses identity mapping (arange(B)). - use_qk_l2norm_in_kernel: Whether to L2-normalize Q/K in kernel (default: True) - scale: Optional attention scale (default: 1/sqrt(K)) + A_log: [HV] float32 + a: [B, T, HV] bf16 + dt_bias: [HV] float32 + q: [B, T, H, K] bf16 + k: [B, T, H, K] bf16 + v: [B, T, HV, V] bf16 + b: [B, T, HV] bf16 + initial_state_source: [pool_size, HV, V, K] bf16 + initial_state_indices: [B] int32 - indices into state pool + intermediate_states_buffer: Optional [pool_size, T, HV, V, K] bf16 + disable_state_update: bool - if True, don't update initial state + scale: Optional, default 1/sqrt(K) + output: Optional pre-allocated output tensor [B, T, HV, V] bf16 Returns: - output: [B, T, HV, V] - - Example: - >>> B, T, H, K = 16, 1, 16, 128 - >>> HV, V = 32, 128 - >>> q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) - >>> k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16) - >>> v = torch.randn(B, T, HV, V, device='cuda', dtype=torch.bfloat16) - >>> a = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) - >>> b = torch.randn(B, T, HV, device='cuda', dtype=torch.bfloat16) - >>> A_log = torch.randn(HV, device='cuda', dtype=torch.float32) - >>> dt_bias = torch.randn(HV, device='cuda', dtype=torch.float32) - >>> h_state = torch.randn(B, HV, V, K, device='cuda', dtype=torch.bfloat16) - >>> output = gated_delta_rule( - ... A_log, a, dt_bias, q=q, k=k, v=v, b=b, - ... initial_state_source=h_state - ... ) + output: [B, T, HV, V] bf16 """ - global _compiled_kernels - - # Validate required Optional parameters - if q is None: - raise ValueError("q (query tensor) is required") - if k is None: - raise ValueError("k (key tensor) is required") - if v is None: - raise ValueError("v (value tensor) is required") - if b is None: - raise ValueError("b (beta gate tensor) is required") - if initial_state_source is None: - raise ValueError("initial_state_source (H state tensor) is required") + global _compiled_kernels_mtp + + assert q is not None and k is not None and v is not None + assert b is not None and initial_state_source is not None B, T, H, K = q.shape - assert T in [1, 2, 3, 4], f"Supported T=1,2,3,4, got T={T}" HV = v.shape[2] V = v.shape[3] pool_size = initial_state_source.shape[0] + assert K == 128 and V == 128, f"K and V must be 128, got K={K}, V={V}" + assert initial_state_source.dtype == torch.bfloat16 if scale is None: scale = 1.0 / math.sqrt(K) - # Resolve indices: identity mapping when not provided if initial_state_indices is None: - h_slot_indices = torch.arange(B, dtype=torch.int32, device=q.device) - elif initial_state_indices.dtype != torch.int32: - h_slot_indices = initial_state_indices.to(torch.int32) + initial_state_indices = torch.arange(B, dtype=torch.int32, device=q.device) + + if output is None: + output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) + + # Reshape state to [pool_size * HV, V, K] + h0_source = initial_state_source.reshape(pool_size * HV, V, K) + + # Handle intermediate states + cache_intermediate_states = intermediate_states_buffer is not None + if cache_intermediate_states: + buffer_size = intermediate_states_buffer.shape[0] + cache_steps = intermediate_states_buffer.shape[1] + assert cache_steps >= T, ( + f"intermediate_states_buffer dim 1 ({cache_steps}) must be >= T={T}" + ) + assert intermediate_states_buffer.dtype == torch.bfloat16 + intermediate_states = intermediate_states_buffer.reshape( + buffer_size * cache_steps * HV, V, K + ) + if not intermediate_states.is_contiguous(): + intermediate_states = intermediate_states.contiguous() else: - h_slot_indices = initial_state_indices + intermediate_states = h0_source[ + :1, :1, :1 + ] # Reuse existing allocation as dummy - output = torch.empty(B, T, HV, V, device=q.device, dtype=q.dtype) + tile_v = _select_tile_v_for_mtp(B, HV, V, T) + h_ = from_dlpack(h0_source, assumed_align=32, enable_tvm_ffi=True) + inter_ = from_dlpack(intermediate_states, assumed_align=32, enable_tvm_ffi=True) q_ = from_dlpack(q, assumed_align=32, enable_tvm_ffi=True) k_ = from_dlpack(k, assumed_align=32, enable_tvm_ffi=True) v_ = from_dlpack(v, assumed_align=32, enable_tvm_ffi=True) @@ -2009,69 +2550,82 @@ def gated_delta_rule( b_ = from_dlpack(b, assumed_align=32, enable_tvm_ffi=True) A_log_ = from_dlpack(A_log, assumed_align=32, enable_tvm_ffi=True) dt_bias_ = from_dlpack(dt_bias, assumed_align=32, enable_tvm_ffi=True) - h_ = from_dlpack(initial_state_source, assumed_align=32, enable_tvm_ffi=True) - h_slot_indices_ = from_dlpack(h_slot_indices, assumed_align=32, enable_tvm_ffi=True) o_ = from_dlpack(output, assumed_align=32, enable_tvm_ffi=True) - - scale_f32 = cutlass.Float32(scale) - softplus_beta_f32 = cutlass.Float32(softplus_beta) - softplus_threshold_f32 = cutlass.Float32(softplus_threshold) - eps_f32 = cutlass.Float32(1e-6) + h0_idx_ = from_dlpack(initial_state_indices, assumed_align=32, enable_tvm_ffi=True) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Check cache - include pool_size so pool and direct paths don't collide - cache_key = (T, B, H, HV, K, V, pool_size) - if cache_key not in _compiled_kernels: - # Select and compile the appropriate kernel - if T == 1 and B <= 4: - launch_fn = gated_delta_rule_launch_seqlen1_lowBS_1chunk - elif T == 1: - launch_fn = gated_delta_rule_launch_seqlen1 - elif T == 2: - launch_fn = gated_delta_rule_launch_seqlen2 - elif T == 3: - launch_fn = gated_delta_rule_launch_seqlen3 - else: # T == 4 - launch_fn = gated_delta_rule_launch_seqlen4 - - _compiled_kernels[cache_key] = cute.compile( - launch_fn, + major, _ = torch.cuda.get_device_capability(q.device) + use_packed_fma = major >= 10 + + cache_key = ( + "mtp_bf16", + B, + T, + H, + HV, + K, + V, + pool_size, + tile_v, + disable_state_update, + cache_intermediate_states, + use_qk_l2norm_in_kernel, + scale, + softplus_beta, + softplus_threshold, + use_packed_fma, + ) + if cache_key not in _compiled_kernels_mtp: + _compiled_kernels_mtp[cache_key] = cute.compile( + run_gdn_decode_bf16state_mtp, + h_, + inter_, + A_log_, + a_, + dt_bias_, q_, k_, v_, - a_, b_, - A_log_, - dt_bias_, - h_, - h_slot_indices_, o_, - scale_f32, - softplus_beta_f32, - softplus_threshold_f32, - eps_f32, + h0_idx_, + softplus_beta, + softplus_threshold, + scale, + HV, + B, + T, + H, + K, + V, + tile_v, + use_qk_l2norm_in_kernel, + disable_state_update, + cache_intermediate_states, + use_packed_fma, stream, - options="--enable-tvm-ffi --generate-line-info", + options="--enable-tvm-ffi --generate-line-info --opt-level 3", ) - # Execute - _compiled_kernels[cache_key]( + _compiled_kernels_mtp[cache_key]( + h_, + inter_, + A_log_, + a_, + dt_bias_, q_, k_, v_, - a_, b_, - A_log_, - dt_bias_, - h_, - h_slot_indices_, o_, - scale_f32, - softplus_beta_f32, - softplus_threshold_f32, - eps_f32, + h0_idx_, stream, ) return output + + +# Backward-compatible aliases +gated_delta_rule_bf16state_cooprow = gated_delta_rule +gated_delta_rule_bf16state_cooprow_mtp = gated_delta_rule_mtp diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 7661f3016f..1b43a0ddfe 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -42,15 +42,16 @@ ) from flashinfer.utils import get_compute_capability -# Import the gdn_decode_klast_bf16_state kernel (T=1..4, bf16 state, K-last layout) +# Import BF16 state kernels (T=1 and MTP) try: from flashinfer.gdn_kernels.gdn_decode_bf16_state import ( - gated_delta_rule as gdn_decode_klast_bf16_state, + gated_delta_rule as gdn_decode_bf16_state, + gated_delta_rule_mtp as gdn_decode_bf16_state_mtp, ) - GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = True + GDN_DECODE_BF16_STATE_AVAILABLE = True except ImportError: - GDN_DECODE_KLAST_BF16_STATE_AVAILABLE = False + GDN_DECODE_BF16_STATE_AVAILABLE = False def _skip_if_not_sm90_or_later(): @@ -60,9 +61,71 @@ def _skip_if_not_sm90_or_later(): pytest.skip(f"GDN decode requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}") +def _assert_close_large_tensor( + actual: torch.Tensor, + expected: torch.Tensor, + atol: float, + rtol: float, + msg: str, + timestep_dim: int | None = None, +): + """Manual assert_close for large tensors that avoids RuntimeError in error formatting. + + torch.testing.assert_close crashes with RuntimeError when trying to format + error messages for tensors with >1B elements. This function computes the + comparison manually and reports per-timestep error diagnostics on failure. + """ + # Compare per-slice to avoid allocating huge temporary tensors + if timestep_dim is not None and actual.ndim > timestep_dim: + T = actual.shape[timestep_dim] + per_t_stats = [] + any_violation = False + for t in range(T): + diff_t = ( + actual.select(timestep_dim, t).float() + - expected.select(timestep_dim, t).float() + ).abs() + tol_t = atol + rtol * expected.select(timestep_dim, t).float().abs() + violations_t = diff_t > tol_t + count = violations_t.sum().item() + total = violations_t.numel() + per_t_stats.append( + (t, diff_t.max().item(), diff_t.mean().item(), count, total) + ) + if count > 0: + any_violation = True + del diff_t, tol_t, violations_t + + if not any_violation: + return + + lines = [msg] + for t, t_max, t_mean, t_count, t_total in per_t_stats: + lines.append( + f" t={t}: max_abs={t_max:.6f}, mean={t_mean:.6f}, " + f"violations={t_count}/{t_total} ({100 * t_count / t_total:.4f}%)" + ) + lines.append(f" Tolerances: atol={atol}, rtol={rtol}") + raise AssertionError("\n".join(lines)) + else: + diff = (actual.float() - expected.float()).abs() + tol = atol + rtol * expected.float().abs() + violations = diff > tol + if not violations.any(): + return + num_violations = violations.sum().item() + total = violations.numel() + raise AssertionError( + f"{msg}\n" + f" Max abs error: {diff.max().item():.6f}, " + f"Violations: {num_violations}/{total} ({100 * num_violations / total:.4f}%), " + f"Tolerances: atol={atol}, rtol={rtol}" + ) + + # ============================================================================ # Test decode kernel with pretranspose version ([B*HV, V, K]) -# Reference: fp32 h state (default); bf16 h state used only for gdn_decode_klast_bf16_state. +# Reference: fp32 h state (default); bf16 h state used only for gdn_decode_bf16_state. # ============================================================================ @@ -202,7 +265,6 @@ def _test_decode_kernel_pretranspose( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_basic_pretranspose( @@ -368,7 +430,6 @@ def _test_decode_kernel_nontranspose( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_basic_nontranspose( @@ -513,7 +574,6 @@ def _test_decode_kernel_pretranspose_pool( @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_pretranspose_pool( @@ -770,7 +830,6 @@ def _test_decode_kernel_pretranspose_pool_all_padding( @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_pretranspose_pool_negative_indices( @@ -798,7 +857,6 @@ def test_decode_kernel_pretranspose_pool_negative_indices( @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_pretranspose_pool_all_padding( @@ -1132,12 +1190,15 @@ def _test_verify_kernel_mtp( -2, -1 ) # [pool_size, T, HV, K, V] - torch.testing.assert_close( + # Use manual comparison to avoid RuntimeError from torch.testing.assert_close + # when formatting error messages for tensors with >1B elements (e.g. [512, 5, 32, 128, 128]) + _assert_close_large_tensor( intermediate_states_kernel.float(), intermediate_states_ref.float(), atol=atol_s, rtol=rtol_s, msg=f"Intermediate states mismatch for MTP kernel (B={B}, T={T}, dtype={dtype})", + timestep_dim=1, ) # Compare final state if state update is enabled @@ -1165,7 +1226,6 @@ def _test_verify_kernel_mtp( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_verify_kernel_mtp( @@ -1206,7 +1266,6 @@ def test_verify_kernel_mtp( @pytest.mark.parametrize("seq_len", [2, 3, 4, 5, 6, 7, 8]) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_mtp_fp32_state_with_cache_and_state_update( @@ -1244,43 +1303,41 @@ def test_mtp_fp32_state_with_cache_and_state_update( # ============================================================================ -# Test gdn_decode_klast_bf16_state kernel (T=1..4, bf16 state, K-last) +# Test BF16 state kernel (T=1..4, bf16 state, K-last) # Reference: bf16 h state only here (state_dtype=torch.bfloat16). Other kernels # above use fp32 h state reference. # ============================================================================ -def _test_gdn_decode_klast_bf16_state_kernel( +def _test_gdn_decode_bf16_state_kernel( dtype: str, batch_size: int, num_q_heads: int, num_k_heads: int, num_v_heads: int, head_size: int, - seq_len: int, # T=1,2,3,4 + seq_len: int, scale: float, alpha: bool, beta: bool, seed: int | None = None, ): - """Test gdn_decode_klast_bf16_state kernel for T=1,2,3,4 with bf16 h state. + """Test BF16 state kernel with bf16 h state. Both kernel and reference use bf16 h state: reference runs with state_dtype=torch.bfloat16 (read h as fp32, compute in fp32, store h in bf16) - so the comparison is apples-to-apples with the gdn_decode_klast_bf16_state kernel. + so the comparison is apples-to-apples with the BF16 state kernel. """ _skip_if_not_sm90_or_later() - if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: - pytest.skip("gdn_decode_klast_bf16_state kernel not available") + if not GDN_DECODE_BF16_STATE_AVAILABLE: + pytest.skip("BF16 state kernel not available") random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - assert seq_len in [1, 2, 3, 4], ( - f"gdn_decode_klast_bf16_state supports T=1,2,3,4, got T={seq_len}" - ) + assert seq_len >= 1, f"seq_len must be >= 1, got T={seq_len}" # State and GDN parameters are based on num_v_heads (HV in kernel API) num_sab_heads = num_v_heads @@ -1297,7 +1354,7 @@ def _test_gdn_decode_klast_bf16_state_kernel( # NOTE: Do NOT pre-normalize K here. Both the kernel (use_qk_l2norm_in_kernel=True) # and reference will apply L2 normalization internally after GQA expansion. - # gdn_decode_klast_bf16_state kernel expects [B, HV, V, K] (K-fast layout) in BF16. + # BF16 state kernel expects [B, HV, V, K] (K-fast layout) in BF16. # Use the same bf16 initial state for both kernel and reference so we # compare the bf16 h state path. input_state_kernel = torch.randn( @@ -1311,7 +1368,7 @@ def _test_gdn_decode_klast_bf16_state_kernel( # A_log: log decay parameter [HV] - must be float32 A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 - # dt_bias: decay bias [HV] - must be float32 for gdn_decode_klast_bf16_state kernel + # dt_bias: decay bias [HV] - must be float32 for BF16 state kernel dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 # a: input-dependent decay [B, T, HV] @@ -1335,9 +1392,10 @@ def _test_gdn_decode_klast_bf16_state_kernel( * 10.0 ) - # Call gdn_decode_klast_bf16_state kernel + # Call BF16 state kernel (T=1 uses gated_delta_rule, T>1 uses MTP) our_state = input_state_kernel.clone() - our_o = gdn_decode_klast_bf16_state( + kernel_fn = gdn_decode_bf16_state if seq_len == 1 else gdn_decode_bf16_state_mtp + our_o = kernel_fn( A_log=A_log, a=a, dt_bias=dt_bias, @@ -1392,7 +1450,7 @@ def _test_gdn_decode_klast_bf16_state_kernel( ref_o.float(), atol=atol_o, rtol=rtol_o, - msg=f"Output mismatch for gdn_decode_klast_bf16_state kernel (B={batch_size}, T={seq_len})", + msg=f"Output mismatch for BF16 state kernel (B={batch_size}, T={seq_len})", ) # Compare states: both in bf16 (kernel [B, HV, V, K], ref [B, HV, K, V]) @@ -1402,11 +1460,11 @@ def _test_gdn_decode_klast_bf16_state_kernel( ref_state_transposed.float(), atol=atol_kv, rtol=rtol_kv, - msg=f"State mismatch for gdn_decode_klast_bf16_state kernel (B={batch_size}, T={seq_len})", + msg=f"State mismatch for BF16 state kernel (B={batch_size}, T={seq_len})", ) print( - f"✓ gdn_decode_klast_bf16_state kernel test passed (batch={batch_size}, T={seq_len}, dtype={dtype}, h_state=bf16)" + f"✓ BF16 state kernel test passed (batch={batch_size}, T={seq_len}, dtype={dtype}, h_state=bf16)" ) @@ -1419,10 +1477,9 @@ def _test_gdn_decode_klast_bf16_state_kernel( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_gdn_decode_klast_bf16_state_kernel( +def test_gdn_decode_bf16_state_kernel( dtype: str, num_q_heads: int, num_k_heads: int, @@ -1436,7 +1493,7 @@ def test_gdn_decode_klast_bf16_state_kernel( seed: int = int(os.environ.get("SEED", "0")), ): scale_val = 1.0 / math.sqrt(head_size) if scale == "auto" else scale - _test_gdn_decode_klast_bf16_state_kernel( + _test_gdn_decode_bf16_state_kernel( dtype, batch_size, num_q_heads, @@ -1451,7 +1508,6 @@ def test_gdn_decode_klast_bf16_state_kernel( ) -@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("seq_len", [1, 2, 3, 4]) @pytest.mark.parametrize("batch_size", [1, 2, 4]) @pytest.mark.parametrize("head_size", [128]) @@ -1459,7 +1515,7 @@ def test_gdn_decode_klast_bf16_state_kernel( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) -def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( +def test_pretranspose_api_uses_gdn_decode_bf16_state( num_q_heads: int, num_k_heads: int, num_v_heads: int, @@ -1468,13 +1524,13 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( seq_len: int, seed: int = int(os.environ.get("SEED", "0")), ): - """Verify gated_delta_rule_decode_pretranspose dispatches to gdn_decode_klast_bf16_state when state is bf16 and T<=4, K=V=128. + """Verify gated_delta_rule_decode_pretranspose dispatches to BF16 state kernel when state is bf16 and K=V=128. - Calls the API with bf16 state and checks output/state match the direct gdn_decode_klast_bf16_state call. + Calls the API with bf16 state and checks output/state match the direct kernel call. """ _skip_if_not_sm90_or_later() - if not GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: - pytest.skip("gdn_decode_klast_bf16_state kernel not available") + if not GDN_DECODE_BF16_STATE_AVAILABLE: + pytest.skip("BF16 state kernel not available") random.seed(seed) torch.random.manual_seed(seed) @@ -1515,7 +1571,7 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( ) state_direct = state_api.clone() - # Via API (should dispatch to gdn_decode_klast_bf16_state) + # Via API (should dispatch to gdn_decode_bf16_state) out_api, state_api = gated_delta_rule_decode_pretranspose( q=q, k=k, @@ -1529,8 +1585,9 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( use_qk_l2norm=True, ) - # Direct improved kernel - out_direct = gdn_decode_klast_bf16_state( + # Direct improved kernel (T=1 uses gdn_decode_bf16_state, T>1 uses MTP variant) + kernel_fn = gdn_decode_bf16_state if seq_len == 1 else gdn_decode_bf16_state_mtp + out_direct = kernel_fn( A_log=A_log, a=a, dt_bias=dt_bias, @@ -1548,7 +1605,376 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( torch.testing.assert_close(out_api, out_direct, atol=1e-2, rtol=1e-2) torch.testing.assert_close(state_api, state_direct, atol=1e-2, rtol=1e-2) print( - f"✓ API gdn_decode_klast_bf16_state backend verified (batch={batch_size}, T={seq_len})" + f"✓ API gdn_decode_bf16_state backend verified (batch={batch_size}, T={seq_len})" + ) + + +# ============================================================================ +# Test BF16 state kernel (T=1) +# ============================================================================ + + +def _test_gdn_decode_bf16_state_t1_kernel( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + scale: float, + alpha: bool, + beta: bool, + seed: int | None = None, +): + """Test BF16 state kernel for T=1. + + Both kernel and reference use bf16 h state so the comparison is apples-to-apples. + """ + _skip_if_not_sm90_or_later() + + if not GDN_DECODE_BF16_STATE_AVAILABLE: + pytest.skip("BF16 state kernel not available") + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=dtype_torch) + k = torch.randn(batch_size, 1, num_k_heads, head_size, dtype=dtype_torch) + v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=dtype_torch) + + input_state_kernel = torch.randn( + batch_size, num_sab_heads, head_size, head_size, dtype=torch.bfloat16 + ) + input_state_ref_bf16 = input_state_kernel.transpose(-2, -1).contiguous() + + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + a = ( + torch.randn(batch_size, 1, num_sab_heads, dtype=dtype_torch, device=device) + * 0.1 + ) + + if beta: + b_tensor = torch.randn( + batch_size, 1, num_sab_heads, dtype=dtype_torch, device=device + ) + else: + b_tensor = ( + torch.ones( + batch_size, 1, num_sab_heads, dtype=dtype_torch, device=device + ) + * 10.0 + ) + + our_state = input_state_kernel.clone() + our_o = gdn_decode_bf16_state( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b_tensor, + initial_state_source=our_state, + use_qk_l2norm_in_kernel=True, + scale=scale, + ) + + torch.cuda.synchronize() + + ref_state = input_state_ref_bf16.clone() + ref_o_t, ref_state = decode_delta_rule( + q[:, 0].float(), + k[:, 0].float(), + v[:, 0].float(), + ref_state, + A_log=A_log, + a=a[:, 0], + dt_bias=dt_bias, + b=b_tensor[:, 0], + scale_factor=scale, + softplus_beta=1.0, + softplus_threshold=20.0, + use_l2_norm=True, + state_dtype=torch.bfloat16, + ) + ref_o = ref_o_t.unsqueeze(1).to(dtype_torch) + + atol_o = 0.001 + rtol_o = 0.005 + # State tolerances slightly higher: BF16 state accumulation at large batch + # sizes can produce diffs up to ~0.016 (1 BF16 ULP at magnitude ~2) + atol_kv = 0.02 + rtol_kv = 0.01 + + torch.testing.assert_close( + our_o.float(), + ref_o.float(), + atol=atol_o, + rtol=rtol_o, + msg=f"Output mismatch for BF16 state kernel (B={batch_size})", + ) + + ref_state_transposed = ref_state.transpose(-2, -1).contiguous() + torch.testing.assert_close( + our_state.float(), + ref_state_transposed.float(), + atol=atol_kv, + rtol=rtol_kv, + msg=f"State mismatch for BF16 state kernel (B={batch_size})", + ) + + print(f" BF16 state T=1 PASS (batch={batch_size}, dtype={dtype})") + + +@pytest.mark.parametrize("beta", [True]) +@pytest.mark.parametrize("alpha", [True]) +@pytest.mark.parametrize("scale", ["auto"]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize( + "num_q_heads, num_k_heads, num_v_heads", + [(16, 16, 32)], +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_gdn_decode_bf16_state_t1_kernel( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + scale: float | str, + alpha: bool, + beta: bool, + seed: int = int(os.environ.get("SEED", "0")), +): + scale_val = 1.0 / math.sqrt(head_size) if scale == "auto" else scale + _test_gdn_decode_bf16_state_t1_kernel( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + scale_val, + alpha, + beta, + seed, + ) + + +# ============================================================================ +# Test BF16 state MTP kernel (T>=2) +# ============================================================================ + + +def _test_gdn_decode_bf16_state_mtp_kernel( + dtype: str, + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + seq_len: int, + scale: float, + cache_intermediate_states: bool, + seed: int | None = None, +): + """Test MTP BF16 state kernel for T>=2. + + Both kernel and reference use bf16 h state. + Tests cache_intermediate_states and disable_state_update=True. + """ + _skip_if_not_sm90_or_later() + + if not GDN_DECODE_BF16_STATE_AVAILABLE: + pytest.skip("BF16 state kernel not available") + + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_sab_heads = num_v_heads + dtype_torch = getattr(torch, dtype) + device = torch.device("cuda") + + with device: + q = torch.randn(batch_size, seq_len, num_q_heads, head_size, dtype=dtype_torch) + k = torch.randn(batch_size, seq_len, num_k_heads, head_size, dtype=dtype_torch) + v = torch.randn(batch_size, seq_len, num_v_heads, head_size, dtype=dtype_torch) + + pool_size = batch_size + input_state_kernel = torch.randn( + pool_size, num_sab_heads, head_size, head_size, dtype=torch.bfloat16 + ) + input_state_ref_bf16 = input_state_kernel.transpose(-2, -1).contiguous() + + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + dt_bias = torch.randn(num_sab_heads, dtype=torch.float32, device=device) * 0.1 + a = ( + torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + * 0.1 + ) + b_tensor = torch.randn( + batch_size, seq_len, num_sab_heads, dtype=dtype_torch, device=device + ) + initial_state_indices = torch.arange( + batch_size, dtype=torch.int32, device=device + ) + + if cache_intermediate_states: + intermediate_states_buffer = torch.zeros( + pool_size, + seq_len, + num_sab_heads, + head_size, + head_size, + dtype=torch.bfloat16, + device=device, + ) + else: + intermediate_states_buffer = None + + # Test with disable_state_update=True (MTP verify mode) + our_state = input_state_kernel.clone() + our_o = gdn_decode_bf16_state_mtp( + A_log=A_log, + a=a, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q, + k=k, + v=v, + b=b_tensor, + initial_state_source=our_state, + initial_state_indices=initial_state_indices, + intermediate_states_buffer=intermediate_states_buffer, + disable_state_update=True, + use_qk_l2norm_in_kernel=True, + scale=scale, + ) + + torch.cuda.synchronize() + + # Reference: step through tokens with bf16 state + ref_state = input_state_ref_bf16.clone() + ref_outputs = [] + ref_intermediate_states = [] + + for t in range(seq_len): + ref_o_t, ref_state = decode_delta_rule( + q[:, t].float(), + k[:, t].float(), + v[:, t].float(), + ref_state, + A_log=A_log, + a=a[:, t], + dt_bias=dt_bias, + b=b_tensor[:, t], + scale_factor=scale, + softplus_beta=1.0, + softplus_threshold=20.0, + use_l2_norm=True, + state_dtype=torch.bfloat16, + ) + ref_outputs.append(ref_o_t) + if cache_intermediate_states: + ref_intermediate_states.append(ref_state.clone()) + + ref_o = torch.stack(ref_outputs, dim=1).to(dtype_torch) + + atol_o = 0.001 + rtol_o = 0.005 + + torch.testing.assert_close( + our_o.float(), + ref_o.float(), + atol=atol_o, + rtol=rtol_o, + msg=f"Output mismatch for MTP BF16 state kernel (B={batch_size}, T={seq_len})", + ) + + # With disable_state_update=True, initial state should be unchanged + torch.testing.assert_close( + our_state.float(), + input_state_kernel.float(), + atol=0, + rtol=0, + msg=f"State should be unchanged with disable_state_update=True (B={batch_size}, T={seq_len})", + ) + + # Check intermediate states buffer contents against reference + if cache_intermediate_states and intermediate_states_buffer is not None: + # intermediate_states_buffer: [pool_size, T, HV, V, K] (K-last layout, bf16) + # ref intermediate states: [B, HV, K, V] per step (K-major layout, bf16) + # Stack ref: [B, T, HV, K, V], transpose to [B, T, HV, V, K] for comparison + ref_inter = torch.stack(ref_intermediate_states, dim=1) # [B, T, HV, K, V] + ref_inter_transposed = ref_inter.transpose( + -2, -1 + ).contiguous() # [B, T, HV, V, K] + + atol_s = 0.02 + rtol_s = 0.01 + torch.testing.assert_close( + intermediate_states_buffer.float(), + ref_inter_transposed.float(), + atol=atol_s, + rtol=rtol_s, + msg=f"Intermediate states mismatch for MTP BF16 state kernel (B={batch_size}, T={seq_len})", + ) + + print( + f" BF16 state MTP PASS (batch={batch_size}, T={seq_len}, " + f"cache_intermediate={cache_intermediate_states})" + ) + + +@pytest.mark.parametrize("cache_intermediate_states", [True, False]) +@pytest.mark.parametrize("seq_len", [2, 4, 8]) +@pytest.mark.parametrize("scale", ["auto"]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize( + "num_q_heads, num_k_heads, num_v_heads", + [(16, 16, 32)], +) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_gdn_decode_bf16_state_mtp_kernel( + dtype: str, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + batch_size: int, + seq_len: int, + scale: float | str, + cache_intermediate_states: bool, + seed: int = int(os.environ.get("SEED", "0")), +): + scale_val = 1.0 / math.sqrt(head_size) if scale == "auto" else scale + _test_gdn_decode_bf16_state_mtp_kernel( + dtype, + batch_size, + num_q_heads, + num_k_heads, + num_v_heads, + head_size, + seq_len, + scale_val, + cache_intermediate_states, + seed, ) @@ -1622,10 +2048,10 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( seed=42, ) - print("\n=== Testing IMPROVED CuTe-DSL version (T=1,2,3,4) ===") - if GDN_DECODE_KLAST_BF16_STATE_AVAILABLE: + print("\n=== Testing BF16 state kernel (T=1,2,3,4) ===") + if GDN_DECODE_BF16_STATE_AVAILABLE: for t in [1, 2, 3, 4]: - _test_gdn_decode_klast_bf16_state_kernel( + _test_gdn_decode_bf16_state_kernel( dtype="bfloat16", batch_size=4, num_q_heads=16, @@ -1639,7 +2065,7 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( seed=42, ) else: - print("⚠ gdn_decode_klast_bf16_state kernel not available, skipping...") + print("⚠ BF16 state kernel not available, skipping...") print("\n✅ All smoke tests passed!") print("\nTo run full test suite:") @@ -1653,6 +2079,6 @@ def test_pretranspose_api_uses_gdn_decode_klast_bf16_state( " MTP (VERIFY): pytest test_decode_delta_rule.py::test_verify_kernel_mtp -v" ) print( - " gdn_decode_klast_bf16_state: pytest test_decode_delta_rule.py::test_gdn_decode_klast_bf16_state_kernel -v" + " gdn_decode_bf16_state: pytest test_decode_delta_rule.py::test_gdn_decode_bf16_state_kernel -v" ) print(" ALL: pytest test_decode_delta_rule.py -v")