diff --git a/benchmarks/bench_gdn_decode.py b/benchmarks/bench_gdn_decode.py index 49c6147912..9ec10b8fa5 100644 --- a/benchmarks/bench_gdn_decode.py +++ b/benchmarks/bench_gdn_decode.py @@ -14,10 +14,33 @@ limitations under the License. """ +""" +GDN Decode Benchmark + +This benchmark supports: +1. All layouts comparison (default for decode): FlashInfer/Triton x pretranspose/nontranspose +2. Single layout comparison: FlashInfer (CuTe DSL) vs Triton kernel (--compare) +3. MTP benchmark (--version mtp) + +Usage: + # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 + + # Single layout comparison: FlashInfer vs Triton + python benchmarks/bench_gdn_decode.py --compare --batch-size 1 4 8 16 32 64 128 256 512 + + # MTP benchmark (FlashInfer only) + python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 32 128 + + # MTP comparison: FlashInfer vs Triton + python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 + + # Use Qwen3-Next preset + python benchmarks/bench_gdn_decode.py --preset qwen3-next --batch-size 1 32 128 512 +""" + import argparse -import json import numpy as np -import os import torch from flashinfer.gdn_decode import ( @@ -25,54 +48,12 @@ gated_delta_rule_decode, gated_delta_rule_mtp, ) +from flashinfer.testing import bench_gpu_time -def parse_trace_file(trace_file: str, version: str = "pretranspose"): - """ - Parse a torch profiler trace file and extract GDN kernel timings. - - Args: - trace_file: Path to the trace JSON file - version: 'pretranspose', 'nontranspose', or 'mtp' - - Returns: - dict with 'kernel_times' list (in microseconds) - """ - with open(trace_file, "r") as f: - trace_data = json.load(f) - - # GDN kernel patterns (CuTe DSL kernels) - if version == "pretranspose": - kernel_patterns = [ - "gdn_decode_kernel_small_batch_pretranspose", # Small batch kernel - "gdn_decode_kernel_big_batch_pretranspose", # Big batch kernel - ] - elif version == "nontranspose": - kernel_patterns = [ - "gdn_decode_kernel_small_batch_nontranspose", # Small batch kernel - "gdn_decode_kernel_big_batch_nontranspose", # Big batch kernel - ] - elif version == "mtp": - kernel_patterns = [ - "gdn_verify_kernel_mtp", # MTP kernel - ] - else: - raise ValueError(f"Unknown version: {version}") - - kernel_durations = [] - - for event in trace_data.get("traceEvents", []): - if event.get("cat") != "kernel": - continue - - name = event.get("name", "") - dur = event.get("dur", 0) # duration in microseconds - - # Check if it's a GDN kernel - if any(pattern in name for pattern in kernel_patterns): - kernel_durations.append(dur) - - return {"kernel_times": kernel_durations} +# ============================================================================ +# Utility Functions +# ============================================================================ def gdn_decode_flops( @@ -191,6 +172,847 @@ def gdn_decode_bytes( return total_bytes +# ============================================================================ +# Triton Kernels for comparison benchmarks +# ============================================================================ + +try: + import triton + import triton.language as tl + + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False + +if TRITON_AVAILABLE: + + @triton.jit + def fused_sigmoid_gating_delta_rule_kernel( + # Pointers to matrices + Q, + K, + V, + O, + H, # Hidden state [B, HV, K, V] + A_LOG, # Log decay [HV] + A, # Input-dependent decay [B, HV] + DT_BIAS, # Decay bias [HV] + B_GATE, # Update gate [B, HV] + # Strides + stride_qb, + stride_qh, + stride_qk, + stride_kb, + stride_kh, + stride_kk, + stride_vb, + stride_vh, + stride_vv, + stride_ob, + stride_oh, + stride_ov, + stride_hb, + stride_hh, + stride_hk, + stride_hv, + # Parameters + softplus_beta: tl.constexpr, + softplus_threshold: tl.constexpr, + scale: tl.constexpr, + use_qk_l2norm: tl.constexpr, + B: tl.constexpr, + HV: tl.constexpr, + H_Q: tl.constexpr, + H_K: tl.constexpr, + K_DIM: tl.constexpr, + V_DIM: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + ): + """ + Triton kernel for fused sigmoid gating delta rule update. + + Follows SGLang's implementation: + 1. g = -exp(A_log) * softplus(a + dt_bias) + 2. beta = sigmoid(b) + 3. h *= exp(g) + 4. v_new = v - k @ h + 5. v_new *= beta + 6. h += outer(k, v_new) + 7. o = q @ h + """ + # Block indices + i_bh = tl.program_id(0) + i_k = tl.program_id(1) + i_v = tl.program_id(2) + + i_b = i_bh // HV + i_hv = i_bh % HV + + # GVA head mapping (num_v_heads > num_q_heads) + h_ratio_q = HV // H_Q + h_ratio_k = HV // H_K + i_hq = i_hv // h_ratio_q + i_hk = i_hv // h_ratio_k + + # Load A_log and dt_bias for this head + b_A_log = tl.load(A_LOG + i_hv).to(tl.float32) + b_dt_bias = tl.load(DT_BIAS + i_hv).to(tl.float32) + + # Load a (input-dependent decay) for this batch and head + b_a = tl.load(A + i_b * HV + i_hv).to(tl.float32) + + # Load b (update gate) for this batch and head + b_b = tl.load(B_GATE + i_b * HV + i_hv).to(tl.float32) + + # Compute softplus: softplus(x) = (1/beta) * log(1 + exp(beta*x)) + x = b_a + b_dt_bias + beta_x = softplus_beta * x + softplus_x = tl.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), + x, + ) + + # Compute g = -exp(A_log) * softplus(a + dt_bias) + b_g = -tl.exp(b_A_log) * softplus_x + + # Compute beta = sigmoid(b) + b_beta = 1.0 / (1.0 + tl.exp(-b_b)) + + # Block offsets + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + # Load q, k, v + p_q = Q + i_b * stride_qb + i_hq * stride_qh + o_k * stride_qk + p_k = K + i_b * stride_kb + i_hk * stride_kh + o_k * stride_kk + p_v = V + i_b * stride_vb + i_hv * stride_vh + o_v * stride_vv + + b_q = tl.load(p_q, mask=o_k < K_DIM, other=0.0).to(tl.float32) + b_k = tl.load(p_k, mask=o_k < K_DIM, other=0.0).to(tl.float32) + b_v = tl.load(p_v, mask=o_v < V_DIM, other=0.0).to(tl.float32) + + # Apply L2 normalization (if enabled) + if use_qk_l2norm: + # Compute L2 norm across K dimension (need reduction across blocks) + # For simplicity, assume single K block + q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8) + k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8) + b_q = b_q / q_norm + b_k = b_k / k_norm + + # Apply scale to q + b_q = b_q * scale + + # Load hidden state h[K, V] from state[B, HV, K, V] + p_h = ( + H + + i_b * stride_hb + + i_hv * stride_hh + + o_k[:, None] * stride_hk + + o_v[None, :] * stride_hv + ) + b_h = tl.load( + p_h, mask=(o_k[:, None] < K_DIM) & (o_v[None, :] < V_DIM), other=0.0 + ).to(tl.float32) + + # Step 1: Apply decay to hidden state: h *= exp(g) + b_h = b_h * tl.exp(b_g) + + # Step 2: Delta rule: v -= sum(h * k, dim=0) = k @ h + # b_h is [BK, BV], b_k is [BK] + # We need to compute k @ h = sum(k[:, None] * h, dim=0) + b_v = b_v - tl.sum(b_h * b_k[:, None], 0) + + # Step 3: Apply beta gating: v *= beta + b_v = b_v * b_beta + + # Step 4: Update hidden state: h += outer(k, v) = k[:, None] * v[None, :] + b_h = b_h + b_k[:, None] * b_v[None, :] + + # Step 5: Compute output: o = q @ h = sum(q[:, None] * h, dim=0) + b_o = tl.sum(b_h * b_q[:, None], 0) + + # Store output + p_o = O + i_b * stride_ob + i_hv * stride_oh + o_v * stride_ov + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=o_v < V_DIM) + + # Store updated hidden state + tl.store( + p_h, + b_h.to(p_h.dtype.element_ty), + mask=(o_k[:, None] < K_DIM) & (o_v[None, :] < V_DIM), + ) + + @triton.jit + def fused_sigmoid_gating_delta_rule_mtp_kernel( + # Pointers to matrices + Q, # [B, T, H_Q, K] + K, # [B, T, H_K, K] + V, # [B, T, HV, V] + O, # [B, T, HV, V] + H, # Hidden state [pool_size, HV, V, K] (K-last layout) + INTERMEDIATE, # Intermediate states [pool_size, T, HV, V, K] + H0_INDICES, # [B] + A_LOG, # Log decay [HV] + A, # Input-dependent decay [B, T, HV] + DT_BIAS, # Decay bias [HV] + B_GATE, # Update gate [B, T, HV] + # Strides for Q, K, V, O [B, T, H, dim] + stride_qb, + stride_qt, + stride_qh, + stride_qk, + stride_kb, + stride_kt, + stride_kh, + stride_kk, + stride_vb, + stride_vt, + stride_vh, + stride_vv, + stride_ob, + stride_ot, + stride_oh, + stride_ov, + # Strides for hidden state [pool_size, HV, V, K] + stride_hp, + stride_hh, + stride_hv, + stride_hk, + # Strides for intermediate states [pool_size, T, HV, V, K] + stride_ip, + stride_it, + stride_ih, + stride_iv, + stride_ik, + # Strides for A [B, T, HV] + stride_ab, + stride_at, + stride_ah, + # Parameters + softplus_beta: tl.constexpr, + softplus_threshold: tl.constexpr, + scale: tl.constexpr, + use_qk_l2norm: tl.constexpr, + disable_state_update: tl.constexpr, + cache_intermediate_states: tl.constexpr, + B: tl.constexpr, + T: tl.constexpr, + HV: tl.constexpr, + H_Q: tl.constexpr, + H_K: tl.constexpr, + K_DIM: tl.constexpr, + V_DIM: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + ): + """ + Triton kernel for MTP (Multiple Token Processing) delta rule update. + Processes T tokens sequentially, updating state after each token. + + Note: The delta rule operations are fundamentally GEMV (matrix-vector) and + rank-1 updates, which don't directly benefit from tensor cores. Tensor cores + are optimized for GEMM (matrix-matrix). To use tensor cores, we would need + to batch across multiple tokens/heads to form proper GEMM operations. + """ + # Block indices + i_bh = tl.program_id(0) + i_k = tl.program_id(1) + i_v = tl.program_id(2) + + i_b = i_bh // HV + i_hv = i_bh % HV + + # GVA head mapping + h_ratio_q = HV // H_Q + h_ratio_k = HV // H_K + i_hq = i_hv // h_ratio_q + i_hk = i_hv // h_ratio_k + + # Load initial state index for this batch + i_pool = tl.load(H0_INDICES + i_b) + + # Load A_log and dt_bias for this head + b_A_log = tl.load(A_LOG + i_hv).to(tl.float32) + b_dt_bias = tl.load(DT_BIAS + i_hv).to(tl.float32) + + # Block offsets + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + # Load initial hidden state h[V, K] from state[pool, HV, V, K] + p_h = ( + H + + i_pool * stride_hp + + i_hv * stride_hh + + o_v[:, None] * stride_hv + + o_k[None, :] * stride_hk + ) + b_h = tl.load( + p_h, mask=(o_v[:, None] < V_DIM) & (o_k[None, :] < K_DIM), other=0.0 + ).to(tl.float32) # [BV, BK] + + # Process each token + for t in range(T): + # Load a for this batch, time, head + b_a = tl.load(A + i_b * stride_ab + t * stride_at + i_hv * stride_ah).to( + tl.float32 + ) + b_b = tl.load( + B_GATE + i_b * stride_ab + t * stride_at + i_hv * stride_ah + ).to(tl.float32) + + # Compute softplus and decay + x = b_a + b_dt_bias + beta_x = softplus_beta * x + softplus_x = tl.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x + b_beta = 1.0 / (1.0 + tl.exp(-b_b)) + + # Load q, k, v for this timestep + p_q = ( + Q + i_b * stride_qb + t * stride_qt + i_hq * stride_qh + o_k * stride_qk + ) + p_k = ( + K + i_b * stride_kb + t * stride_kt + i_hk * stride_kh + o_k * stride_kk + ) + p_v = ( + V + i_b * stride_vb + t * stride_vt + i_hv * stride_vh + o_v * stride_vv + ) + + b_q = tl.load(p_q, mask=o_k < K_DIM, other=0.0).to(tl.float32) + b_k = tl.load(p_k, mask=o_k < K_DIM, other=0.0).to(tl.float32) + b_v = tl.load(p_v, mask=o_v < V_DIM, other=0.0).to(tl.float32) + + # Apply L2 normalization + if use_qk_l2norm: + q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8) + k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8) + b_q = b_q / q_norm + b_k = b_k / k_norm + + b_q = b_q * scale + + # Step 1: Apply decay: h *= exp(g) + b_h = b_h * tl.exp(b_g) + + # Step 2: Delta rule: v -= h @ k (h is [BV, BK], k is [BK]) + # This is GEMV, which doesn't directly use tensor cores efficiently + # h @ k = sum(h * k[None, :], axis=1) -> [BV] + b_v = b_v - tl.sum(b_h * b_k[None, :], 1) + + # Step 3: Apply beta gating + b_v = b_v * b_beta + + # Step 4: Update state: h += outer(v, k) = v[:, None] * k[None, :] + # This is a rank-1 update + b_h = b_h + b_v[:, None] * b_k[None, :] + + # Step 5: Compute output: o = h @ q = sum(h * q[None, :], axis=1) -> [BV] + # This is also GEMV + b_o = tl.sum(b_h * b_q[None, :], 1) + + # Store output for this timestep + p_o = ( + O + i_b * stride_ob + t * stride_ot + i_hv * stride_oh + o_v * stride_ov + ) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=o_v < V_DIM) + + # Cache intermediate state if needed + if cache_intermediate_states: + p_inter = ( + INTERMEDIATE + + i_pool * stride_ip + + t * stride_it + + i_hv * stride_ih + + o_v[:, None] * stride_iv + + o_k[None, :] * stride_ik + ) + tl.store( + p_inter, + b_h.to(p_inter.dtype.element_ty), + mask=(o_v[:, None] < V_DIM) & (o_k[None, :] < K_DIM), + ) + + # Store final state if state update is enabled + if not disable_state_update: + tl.store( + p_h, + b_h.to(p_h.dtype.element_ty), + mask=(o_v[:, None] < V_DIM) & (o_k[None, :] < K_DIM), + ) + + def triton_gdn_decode( + q: torch.Tensor, # [B, 1, H_Q, K] + k: torch.Tensor, # [B, 1, H_K, K] + v: torch.Tensor, # [B, 1, HV, V] + state: torch.Tensor, # [B, HV, K, V] + A_log: torch.Tensor, # [HV] + a: torch.Tensor, # [B, 1, HV] + dt_bias: torch.Tensor, # [HV] + b: torch.Tensor, # [B, 1, HV] + scale: float, + output: torch.Tensor, # [B, 1, HV, V] + use_qk_l2norm: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + ): + """ + Triton-based GDN decode matching SGLang's implementation. + """ + B, T, H_Q, K_DIM = q.shape + _, _, H_K, _ = k.shape + _, _, HV, V_DIM = v.shape + + assert T == 1, "Triton kernel only supports decode (T=1)" + + # Reshape inputs for kernel + q_flat = q.squeeze(1) # [B, H_Q, K] + k_flat = k.squeeze(1) # [B, H_K, K] + v_flat = v.squeeze(1) # [B, HV, V] + a_flat = a.squeeze(1) # [B, HV] + b_flat = b.squeeze(1) # [B, HV] + o_flat = output.squeeze(1) # [B, HV, V] + + # Block sizes + BK = triton.next_power_of_2(K_DIM) + BV = triton.next_power_of_2(V_DIM) + + # Limit block sizes (BV smaller to allow more V blocks) + BV = min(BV, 32) + + # Number of blocks + NK = triton.cdiv(K_DIM, BK) + NV = triton.cdiv(V_DIM, BV) + + assert NK == 1, f"Multi-block K not supported: NK={NK}" + + # Launch kernel + grid = (B * HV, NK, NV) + + fused_sigmoid_gating_delta_rule_kernel[grid]( + q_flat, + k_flat, + v_flat, + o_flat, + state, + A_log, + a_flat, + dt_bias, + b_flat, + # Strides for q [B, H_Q, K] + q_flat.stride(0), + q_flat.stride(1), + q_flat.stride(2), + # Strides for k [B, H_K, K] + k_flat.stride(0), + k_flat.stride(1), + k_flat.stride(2), + # Strides for v [B, HV, V] + v_flat.stride(0), + v_flat.stride(1), + v_flat.stride(2), + # Strides for o [B, HV, V] + o_flat.stride(0), + o_flat.stride(1), + o_flat.stride(2), + # Strides for h [B, HV, K, V] + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + # Parameters + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + use_qk_l2norm=use_qk_l2norm, + B=B, + HV=HV, + H_Q=H_Q, + H_K=H_K, + K_DIM=K_DIM, + V_DIM=V_DIM, + BK=BK, + BV=BV, + ) + + return output, state + + @triton.jit + def fused_sigmoid_gating_delta_rule_kernel_pretranspose( + # Pointers to matrices + Q, + K, + V, + O, + H, # Hidden state [B, HV, V, K] - V-major (pretranspose) layout + A_LOG, # Log decay [HV] + A, # Input-dependent decay [B, HV] + DT_BIAS, # Decay bias [HV] + B_GATE, # Update gate [B, HV] + # Strides + stride_qb, + stride_qh, + stride_qk, + stride_kb, + stride_kh, + stride_kk, + stride_vb, + stride_vh, + stride_vv, + stride_ob, + stride_oh, + stride_ov, + stride_hb, + stride_hh, + stride_hv, # V dimension stride + stride_hk, # K dimension stride + # Parameters + softplus_beta: tl.constexpr, + softplus_threshold: tl.constexpr, + scale: tl.constexpr, + use_qk_l2norm: tl.constexpr, + B: tl.constexpr, + HV: tl.constexpr, + H_Q: tl.constexpr, + H_K: tl.constexpr, + K_DIM: tl.constexpr, + V_DIM: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + ): + """ + Triton kernel for pretranspose layout [B, HV, V, K]. + + Key difference from nontranspose: + - State layout: [B, HV, V, K] instead of [B, HV, K, V] + - h is [BV, BK] instead of [BK, BV] + - h @ k = sum(h * k[None, :], axis=1) -> [BV] + - h += outer(v, k) = v[:, None] * k[None, :] + - o = h @ q = sum(h * q[None, :], axis=1) -> [BV] + """ + # Block indices + i_bh = tl.program_id(0) + i_k = tl.program_id(1) + i_v = tl.program_id(2) + + i_b = i_bh // HV + i_hv = i_bh % HV + + # GVA head mapping (num_v_heads > num_q_heads) + h_ratio_q = HV // H_Q + h_ratio_k = HV // H_K + i_hq = i_hv // h_ratio_q + i_hk = i_hv // h_ratio_k + + # Load A_log and dt_bias for this head + b_A_log = tl.load(A_LOG + i_hv).to(tl.float32) + b_dt_bias = tl.load(DT_BIAS + i_hv).to(tl.float32) + + # Load a (input-dependent decay) for this batch and head + b_a = tl.load(A + i_b * HV + i_hv).to(tl.float32) + + # Load b (update gate) for this batch and head + b_b = tl.load(B_GATE + i_b * HV + i_hv).to(tl.float32) + + # Compute softplus: softplus(x) = (1/beta) * log(1 + exp(beta*x)) + x = b_a + b_dt_bias + beta_x = softplus_beta * x + softplus_x = tl.where( + beta_x <= softplus_threshold, + (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)), + x, + ) + + # Compute g = -exp(A_log) * softplus(a + dt_bias) + b_g = -tl.exp(b_A_log) * softplus_x + + # Compute beta = sigmoid(b) + b_beta = 1.0 / (1.0 + tl.exp(-b_b)) + + # Block offsets + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + # Load q, k, v + p_q = Q + i_b * stride_qb + i_hq * stride_qh + o_k * stride_qk + p_k = K + i_b * stride_kb + i_hk * stride_kh + o_k * stride_kk + p_v = V + i_b * stride_vb + i_hv * stride_vh + o_v * stride_vv + + b_q = tl.load(p_q, mask=o_k < K_DIM, other=0.0).to(tl.float32) + b_k = tl.load(p_k, mask=o_k < K_DIM, other=0.0).to(tl.float32) + b_v = tl.load(p_v, mask=o_v < V_DIM, other=0.0).to(tl.float32) + + # Apply L2 normalization (if enabled) + if use_qk_l2norm: + q_norm = tl.sqrt(tl.sum(b_q * b_q) + 1e-8) + k_norm = tl.sqrt(tl.sum(b_k * b_k) + 1e-8) + b_q = b_q / q_norm + b_k = b_k / k_norm + + # Apply scale to q + b_q = b_q * scale + + # Load hidden state h[V, K] from state[B, HV, V, K] - pretranspose layout + p_h = ( + H + + i_b * stride_hb + + i_hv * stride_hh + + o_v[:, None] * stride_hv + + o_k[None, :] * stride_hk + ) + b_h = tl.load( + p_h, mask=(o_v[:, None] < V_DIM) & (o_k[None, :] < K_DIM), other=0.0 + ).to(tl.float32) # [BV, BK] + + # Step 1: Apply decay to hidden state: h *= exp(g) + b_h = b_h * tl.exp(b_g) + + # Step 2: Delta rule: v -= h @ k = sum(h * k[None, :], axis=1) + # b_h is [BV, BK], b_k is [BK] + b_v = b_v - tl.sum(b_h * b_k[None, :], 1) + + # Step 3: Apply beta gating: v *= beta + b_v = b_v * b_beta + + # Step 4: Update hidden state: h += outer(v, k) = v[:, None] * k[None, :] + b_h = b_h + b_v[:, None] * b_k[None, :] + + # Step 5: Compute output: o = h @ q = sum(h * q[None, :], axis=1) + b_o = tl.sum(b_h * b_q[None, :], 1) + + # Store output + p_o = O + i_b * stride_ob + i_hv * stride_oh + o_v * stride_ov + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=o_v < V_DIM) + + # Store updated hidden state + tl.store( + p_h, + b_h.to(p_h.dtype.element_ty), + mask=(o_v[:, None] < V_DIM) & (o_k[None, :] < K_DIM), + ) + + def triton_gdn_decode_pretranspose( + q: torch.Tensor, # [B, 1, H_Q, K] + k: torch.Tensor, # [B, 1, H_K, K] + v: torch.Tensor, # [B, 1, HV, V] + state: torch.Tensor, # [B, HV, V, K] - pretranspose layout + A_log: torch.Tensor, # [HV] + a: torch.Tensor, # [B, 1, HV] + dt_bias: torch.Tensor, # [HV] + b: torch.Tensor, # [B, 1, HV] + scale: float, + output: torch.Tensor, # [B, 1, HV, V] + use_qk_l2norm: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + ): + """ + Triton-based GDN decode for pretranspose layout [B, HV, V, K]. + """ + B, T, H_Q, K_DIM = q.shape + _, _, H_K, _ = k.shape + _, _, HV, V_DIM = v.shape + + assert T == 1, "Triton kernel only supports decode (T=1)" + + # Reshape inputs for kernel + q_flat = q.squeeze(1) # [B, H_Q, K] + k_flat = k.squeeze(1) # [B, H_K, K] + v_flat = v.squeeze(1) # [B, HV, V] + a_flat = a.squeeze(1) # [B, HV] + b_flat = b.squeeze(1) # [B, HV] + o_flat = output.squeeze(1) # [B, HV, V] + + # Block sizes + BK = triton.next_power_of_2(K_DIM) + BV = triton.next_power_of_2(V_DIM) + + # Limit block sizes (BV smaller to allow more V blocks) + BV = min(BV, 32) + + # Number of blocks + NK = triton.cdiv(K_DIM, BK) + NV = triton.cdiv(V_DIM, BV) + + assert NK == 1, f"Multi-block K not supported: NK={NK}" + + # Launch kernel + grid = (B * HV, NK, NV) + + fused_sigmoid_gating_delta_rule_kernel_pretranspose[grid]( + q_flat, + k_flat, + v_flat, + o_flat, + state, + A_log, + a_flat, + dt_bias, + b_flat, + # Strides for q [B, H_Q, K] + q_flat.stride(0), + q_flat.stride(1), + q_flat.stride(2), + # Strides for k [B, H_K, K] + k_flat.stride(0), + k_flat.stride(1), + k_flat.stride(2), + # Strides for v [B, HV, V] + v_flat.stride(0), + v_flat.stride(1), + v_flat.stride(2), + # Strides for o [B, HV, V] + o_flat.stride(0), + o_flat.stride(1), + o_flat.stride(2), + # Strides for h [B, HV, V, K] - pretranspose layout + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + # Parameters + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + use_qk_l2norm=use_qk_l2norm, + B=B, + HV=HV, + H_Q=H_Q, + H_K=H_K, + K_DIM=K_DIM, + V_DIM=V_DIM, + BK=BK, + BV=BV, + ) + + return output, state + + def triton_gdn_mtp( + q: torch.Tensor, # [B, T, H_Q, K] + k: torch.Tensor, # [B, T, H_K, K] + v: torch.Tensor, # [B, T, HV, V] + initial_state: torch.Tensor, # [pool_size, HV, V, K] + initial_state_indices: torch.Tensor, # [B] + A_log: torch.Tensor, # [HV] + a: torch.Tensor, # [B, T, HV] + dt_bias: torch.Tensor, # [HV] + b: torch.Tensor, # [B, T, HV] + scale: float, + output: torch.Tensor, # [B, T, HV, V] + intermediate_states_buffer: torch.Tensor = None, # [pool_size, T, HV, V, K] + disable_state_update: bool = True, + use_qk_l2norm: bool = True, + softplus_beta: float = 1.0, + softplus_threshold: float = 20.0, + ): + """ + Triton-based GDN MTP matching SGLang's implementation. + """ + B, T, H_Q, K_DIM = q.shape + _, _, H_K, _ = k.shape + _, _, HV, V_DIM = v.shape + + # Block sizes (BV smaller to allow more V blocks) + BK = triton.next_power_of_2(K_DIM) + BV = triton.next_power_of_2(V_DIM) + BV = min(BV, 32) + + NK = triton.cdiv(K_DIM, BK) + NV = triton.cdiv(V_DIM, BV) + + assert NK == 1, f"Multi-block K not supported: NK={NK}" + + cache_intermediate_states = intermediate_states_buffer is not None + if cache_intermediate_states: + intermediate = intermediate_states_buffer + else: + intermediate = torch.zeros( + 1, 1, 1, 1, 1, dtype=torch.float32, device=q.device + ) + + # Launch kernel + grid = (B * HV, NK, NV) + + fused_sigmoid_gating_delta_rule_mtp_kernel[grid]( + q, + k, + v, + output, + initial_state, + intermediate, + initial_state_indices, + A_log, + a, + dt_bias, + b, + # Q strides + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + # K strides + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + # V strides + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + # O strides + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + # H strides [pool_size, HV, V, K] + initial_state.stride(0), + initial_state.stride(1), + initial_state.stride(2), + initial_state.stride(3), + # Intermediate strides [pool_size, T, HV, V, K] + intermediate.stride(0), + intermediate.stride(1), + intermediate.stride(2) if cache_intermediate_states else 0, + intermediate.stride(3) if cache_intermediate_states else 0, + intermediate.stride(4) if cache_intermediate_states else 0, + # A strides [B, T, HV] + a.stride(0), + a.stride(1), + a.stride(2), + # Parameters + softplus_beta=softplus_beta, + softplus_threshold=softplus_threshold, + scale=scale, + use_qk_l2norm=use_qk_l2norm, + disable_state_update=disable_state_update, + cache_intermediate_states=cache_intermediate_states, + B=B, + T=T, + HV=HV, + H_Q=H_Q, + H_K=H_K, + K_DIM=K_DIM, + V_DIM=V_DIM, + BK=BK, + BV=BV, + ) + + return output, initial_state + + +# ============================================================================ +# FlashInfer-only Benchmark Functions +# ============================================================================ + + def bench_gdn_decode( batch_size: int, num_q_heads: int, @@ -198,14 +1020,14 @@ def bench_gdn_decode( num_v_heads: int, head_size: int, dtype: torch.dtype, - version: str = "pretranspose", + version: str = "nontranspose", use_alpha: bool = True, use_beta: bool = True, use_qk_l2norm: bool = True, warmup_iters: int = 10, bench_iters: int = 100, ): - """Benchmark GDN decode kernel using torch.profiler. + """Benchmark GDN decode kernel using bench_gpu_time with CUPTI. Args: version: 'pretranspose' or 'nontranspose' @@ -262,43 +1084,18 @@ def bench_gdn_decode( else: raise ValueError(f"Unknown version: {version}") - # Warmup - for _ in range(warmup_iters): - _, _ = decode_func( + # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing) + kernel_times_ms = bench_gpu_time( + lambda: decode_func( q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm - ) - torch.cuda.synchronize() - - # Benchmark with torch.profiler - trace_dir = "gdn_decode_traces" - os.makedirs(trace_dir, exist_ok=True) - trace_file = os.path.join(trace_dir, f"trace_{version}_B{batch_size}.json") - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - ) as profiler: - for i in range(bench_iters): - with torch.profiler.record_function(f"gdn_decode_iter{i}"): - _, _ = decode_func( - q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm - ) - - profiler.export_chrome_trace(trace_file) - - # Parse trace file for kernel-level timing - trace_results = parse_trace_file(trace_file, version=version) - kernel_times = trace_results["kernel_times"] # in microseconds - - # Calculate statistics from kernel trace - kernel_median_us = np.median(kernel_times) if kernel_times else 0 - kernel_mean_us = np.mean(kernel_times) if kernel_times else 0 - kernel_std_us = np.std(kernel_times) if kernel_times else 0 + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) - # Calculate metrics (convert us to ms for FLOPS calculation) + # Calculate metrics + kernel_median_ms = np.median(kernel_times_ms) flops = gdn_decode_flops( batch_size, num_q_heads, num_k_heads, num_v_heads, head_size ) @@ -313,7 +1110,6 @@ def bench_gdn_decode( disable_state_update=False, # Decode mode: state is read + written ) - kernel_median_ms = kernel_median_us / 1000 kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 kernel_tb_per_sec = ( bytes_accessed / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 @@ -321,17 +1117,9 @@ def bench_gdn_decode( return { "batch_size": batch_size, - "num_q_heads": num_q_heads, - "num_k_heads": num_k_heads, - "num_v_heads": num_v_heads, - "head_size": head_size, - "dtype": str(dtype).replace("torch.", ""), - "kernel_median_us": kernel_median_us, - "kernel_mean_us": kernel_mean_us, - "kernel_std_us": kernel_std_us, + "kernel_median_us": kernel_median_ms * 1000, "kernel_tflops": kernel_tflops, "kernel_tb_per_sec": kernel_tb_per_sec, - "trace_file": trace_file, } @@ -350,7 +1138,7 @@ def bench_gdn_mtp( warmup_iters: int = 10, bench_iters: int = 100, ): - """Benchmark GDN MTP kernel using torch.profiler.""" + """Benchmark GDN MTP kernel using bench_gpu_time with CUPTI.""" num_o_heads = max(num_q_heads, num_v_heads) num_sab_heads = num_o_heads @@ -408,9 +1196,9 @@ def bench_gdn_mtp( # Scale factor scale = 1.0 / (head_size**0.5) - # Warmup - for _ in range(warmup_iters): - _, _ = gated_delta_rule_mtp( + # Benchmark with bench_gpu_time (CUPTI for accurate kernel timing) + kernel_times_ms = bench_gpu_time( + lambda: gated_delta_rule_mtp( q, k, v, @@ -425,52 +1213,14 @@ def bench_gdn_mtp( intermediate_states_buffer, disable_state_update=True, use_qk_l2norm=use_qk_l2norm, - ) - torch.cuda.synchronize() - - # Benchmark with torch.profiler - trace_dir = "gdn_decode_traces" - os.makedirs(trace_dir, exist_ok=True) - trace_file = os.path.join(trace_dir, f"trace_mtp_B{batch_size}_T{seq_len}.json") - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - ) as profiler: - for i in range(bench_iters): - with torch.profiler.record_function(f"gdn_mtp_iter{i}"): - _, _ = gated_delta_rule_mtp( - q, - k, - v, - initial_state, - initial_state_indices, - A_log, - a, - dt_bias, - b, - scale, - output, - intermediate_states_buffer, - disable_state_update=True, - use_qk_l2norm=use_qk_l2norm, - ) - - profiler.export_chrome_trace(trace_file) - - # Parse trace file for kernel-level timing - trace_results = parse_trace_file(trace_file, version="mtp") - kernel_times = trace_results["kernel_times"] # in microseconds - - # Calculate statistics from kernel trace - kernel_median_us = np.median(kernel_times) if kernel_times else 0 - kernel_mean_us = np.mean(kernel_times) if kernel_times else 0 - kernel_std_us = np.std(kernel_times) if kernel_times else 0 + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) # Calculate metrics + kernel_median_ms = np.median(kernel_times_ms) flops = gdn_decode_flops( batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len ) @@ -485,7 +1235,6 @@ def bench_gdn_mtp( disable_state_update=True, # MTP mode: state is not written back ) - kernel_median_ms = kernel_median_us / 1000 kernel_tflops = flops / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 kernel_tb_per_sec = ( bytes_accessed / kernel_median_ms / 1e9 if kernel_median_ms > 0 else 0 @@ -494,111 +1243,852 @@ def bench_gdn_mtp( return { "batch_size": batch_size, "seq_len": seq_len, - "num_q_heads": num_q_heads, - "num_k_heads": num_k_heads, - "num_v_heads": num_v_heads, - "head_size": head_size, - "dtype": str(dtype).replace("torch.", ""), - "kernel_median_us": kernel_median_us, - "kernel_mean_us": kernel_mean_us, - "kernel_std_us": kernel_std_us, + "kernel_median_us": kernel_median_ms * 1000, "kernel_tflops": kernel_tflops, "kernel_tb_per_sec": kernel_tb_per_sec, - "trace_file": trace_file, } -def main(): - parser = argparse.ArgumentParser(description="Benchmark GDN Decode Kernel") - parser.add_argument( - "--batch-size", - type=int, - nargs="+", - default=[1, 4, 8, 16, 32, 64, 128, 256, 512], - help="Batch sizes to benchmark (number of concurrent decode requests)", +# ============================================================================ +# Comparison Benchmark Functions (FlashInfer vs Triton) +# ============================================================================ + + +def bench_comparison( + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + warmup_iters: int = 10, + bench_iters: int = 100, +): + """Benchmark both FlashInfer and Triton implementations.""" + if not TRITON_AVAILABLE: + raise RuntimeError("Triton is not available. Install with: pip install triton") + + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs (T=1 for decode) + T = 1 + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + # Scale factor + scale = 1.0 / (head_size**0.5) + + # ========== FlashInfer Benchmark ========== + # State for FlashInfer (K-major layout) [B, HV, K, V] + state_fi = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", ) - parser.add_argument("--num-q-heads", type=int, default=16) - parser.add_argument("--num-k-heads", type=int, default=16) - parser.add_argument("--num-v-heads", type=int, default=32) - parser.add_argument("--head-size", type=int, default=128) - parser.add_argument( - "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16" + output_fi = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" ) - parser.add_argument( - "--preset", - type=str, - choices=["qwen3-next", "custom"], - default="custom", - help="Use preset config. qwen3-next: q=k=16, v=32, d=128", + + flashinfer_times = bench_gpu_time( + lambda: gated_delta_rule_decode( + q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, ) - parser.add_argument( - "--no-qk-l2norm", - action="store_true", - help="Disable Q/K L2 normalization", + flashinfer_median_us = np.median(flashinfer_times) * 1000 + + # ========== Triton Benchmark ========== + # State [B, HV, K, V] + state_tr = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", ) - parser.add_argument( - "--version", - type=str, - choices=["pretranspose", "nontranspose", "mtp", "all"], - default="nontranspose", - help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), or all", + output_tr = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" ) - parser.add_argument( - "--seq-len", - type=int, - nargs="+", - default=[2, 4, 8], - help="Sequence lengths for MTP benchmark (T > 1)", + + triton_times = bench_gpu_time( + lambda: triton_gdn_decode( + q, k, v, state_tr, A_log, a, dt_bias, b, scale, output_tr, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, ) - parser.add_argument( - "--cache-intermediate-states", - action="store_true", - help="Cache intermediate states for MTP benchmark", + triton_median_us = np.median(triton_times) * 1000 + + # Calculate metrics + flops = gdn_decode_flops( + batch_size, num_q_heads, num_k_heads, num_v_heads, head_size ) - parser.add_argument( - "--warmup", - type=int, - default=10, - help="Number of warmup iterations", + + flashinfer_tflops = ( + flops / (flashinfer_median_us / 1000) / 1e9 if flashinfer_median_us > 0 else 0 ) - parser.add_argument( - "--iters", - type=int, - default=100, - help="Number of benchmark iterations", + triton_tflops = ( + flops / (triton_median_us / 1000) / 1e9 if triton_median_us > 0 else 0 ) - args = parser.parse_args() - # Apply preset configurations - if args.preset == "qwen3-next": - # Qwen3-Next-80B-A3B linear attention config (GVA) - args.num_q_heads = 16 - args.num_k_heads = 16 - args.num_v_heads = 32 - args.head_size = 128 + speedup = triton_median_us / flashinfer_median_us if flashinfer_median_us > 0 else 0 - # Check SM90 support - device_capability = torch.cuda.get_device_capability() - if device_capability[0] < 9: - print(f"Current device capability: {device_capability}") - print("GDN requires SM90 (Hopper) or later. Exiting...") - return + return { + "batch_size": batch_size, + "flashinfer_us": flashinfer_median_us, + "triton_us": triton_median_us, + "flashinfer_tflops": flashinfer_tflops, + "triton_tflops": triton_tflops, + "speedup": speedup, + } - dtype = getattr(torch, args.dtype) - use_qk_l2norm = not args.no_qk_l2norm - # Determine which versions to benchmark - if args.version == "all": - versions_to_bench = ["pretranspose", "nontranspose", "mtp"] - else: - versions_to_bench = [args.version] +def bench_comparison_pretranspose( + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + warmup_iters: int = 10, + bench_iters: int = 100, +): + """Benchmark both FlashInfer and Triton pretranspose implementations.""" + if not TRITON_AVAILABLE: + raise RuntimeError("Triton is not available. Install with: pip install triton") - for version in versions_to_bench: - if version == "mtp": - # Benchmark MTP version - print( - f"\nGDN MTP Benchmark " - f"(heads: q={args.num_q_heads}, k={args.num_k_heads}, " + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs (T=1 for decode) + T = 1 + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + # Scale factor + scale = 1.0 / (head_size**0.5) + + # ========== FlashInfer Benchmark ========== + # State for FlashInfer pretranspose (V-major layout) [B, HV, V, K] + state_fi = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output_fi = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + flashinfer_times = bench_gpu_time( + lambda: gated_delta_rule_decode_pretranspose( + q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + flashinfer_median_us = np.median(flashinfer_times) * 1000 + + # ========== Triton Benchmark ========== + # State [B, HV, V, K] - pretranspose layout + state_tr = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output_tr = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + triton_times = bench_gpu_time( + lambda: triton_gdn_decode_pretranspose( + q, k, v, state_tr, A_log, a, dt_bias, b, scale, output_tr, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + triton_median_us = np.median(triton_times) * 1000 + + # Calculate metrics + flops = gdn_decode_flops( + batch_size, num_q_heads, num_k_heads, num_v_heads, head_size + ) + + flashinfer_tflops = ( + flops / (flashinfer_median_us / 1000) / 1e9 if flashinfer_median_us > 0 else 0 + ) + triton_tflops = ( + flops / (triton_median_us / 1000) / 1e9 if triton_median_us > 0 else 0 + ) + + speedup = triton_median_us / flashinfer_median_us if flashinfer_median_us > 0 else 0 + + return { + "batch_size": batch_size, + "flashinfer_us": flashinfer_median_us, + "triton_us": triton_median_us, + "flashinfer_tflops": flashinfer_tflops, + "triton_tflops": triton_tflops, + "speedup": speedup, + } + + +def bench_mtp_comparison( + batch_size: int, + seq_len: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + cache_intermediate_states: bool = False, + warmup_iters: int = 10, + bench_iters: int = 100, +): + """Benchmark both FlashInfer and Triton MTP implementations.""" + if not TRITON_AVAILABLE: + raise RuntimeError("Triton is not available. Install with: pip install triton") + + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs + T = seq_len + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + # Scale factor + scale = 1.0 / (head_size**0.5) + + # Pool size = batch size for this benchmark + pool_size = batch_size + + # ========== FlashInfer Benchmark ========== + # State for FlashInfer (K-last layout) [pool_size, HV, V, K] + state_fi = torch.randn( + pool_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output_fi = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + initial_state_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda") + + # Intermediate states buffer + if cache_intermediate_states: + intermediate_fi = torch.zeros( + pool_size, + T, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + else: + intermediate_fi = None + + flashinfer_times = bench_gpu_time( + lambda: gated_delta_rule_mtp( + q, + k, + v, + state_fi, + initial_state_indices, + A_log, + a, + dt_bias, + b, + scale, + output_fi, + intermediate_fi, + disable_state_update=True, + use_qk_l2norm=use_qk_l2norm, + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + flashinfer_median_us = np.median(flashinfer_times) * 1000 + + # ========== Triton Benchmark ========== + # State for Triton [pool_size, HV, V, K] + state_tr = torch.randn( + pool_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output_tr = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + if cache_intermediate_states: + intermediate_tr = torch.zeros( + pool_size, + T, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + else: + intermediate_tr = None + + triton_times = bench_gpu_time( + lambda: triton_gdn_mtp( + q, + k, + v, + state_tr, + initial_state_indices, + A_log, + a, + dt_bias, + b, + scale, + output_tr, + intermediate_tr, + disable_state_update=True, + use_qk_l2norm=use_qk_l2norm, + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + triton_median_us = np.median(triton_times) * 1000 + + # Calculate metrics + flops = gdn_decode_flops( + batch_size, num_q_heads, num_k_heads, num_v_heads, head_size, seq_len + ) + + flashinfer_tflops = ( + flops / (flashinfer_median_us / 1000) / 1e9 if flashinfer_median_us > 0 else 0 + ) + triton_tflops = ( + flops / (triton_median_us / 1000) / 1e9 if triton_median_us > 0 else 0 + ) + + speedup = triton_median_us / flashinfer_median_us if flashinfer_median_us > 0 else 0 + + return { + "batch_size": batch_size, + "seq_len": seq_len, + "flashinfer_us": flashinfer_median_us, + "triton_us": triton_median_us, + "flashinfer_tflops": flashinfer_tflops, + "triton_tflops": triton_tflops, + "speedup": speedup, + } + + +def verify_correctness( + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + rtol: float = 1e-2, + atol: float = 1e-2, +): + """Verify FlashInfer and Triton produce similar results.""" + if not TRITON_AVAILABLE: + raise RuntimeError("Triton is not available. Install with: pip install triton") + + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs (T=1 for decode) + T = 1 + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + # Scale factor + scale = 1.0 / (head_size**0.5) + + # Same initial state for both + state_init = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + + # FlashInfer + state_fi = state_init.clone() + output_fi = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + gated_delta_rule_decode( + q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm + ) + + # Triton + state_tr = state_init.clone() + output_tr = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + triton_gdn_decode( + q, k, v, state_tr, A_log, a, dt_bias, b, scale, output_tr, use_qk_l2norm + ) + + # Compare outputs using torch.testing.assert_close + try: + torch.testing.assert_close( + output_fi.float(), output_tr.float(), rtol=rtol, atol=atol + ) + output_close = True + except AssertionError as e: + output_close = False + print(f" Output mismatch: {e}") + + try: + torch.testing.assert_close( + state_fi.float(), state_tr.float(), rtol=rtol, atol=atol + ) + state_close = True + except AssertionError as e: + state_close = False + print(f" State mismatch: {e}") + + return output_close and state_close + + +def verify_correctness_pretranspose( + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + rtol: float = 1e-2, + atol: float = 1e-2, +): + """Verify FlashInfer and Triton pretranspose produce similar results.""" + if not TRITON_AVAILABLE: + raise RuntimeError("Triton is not available. Install with: pip install triton") + + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs (T=1 for decode) + T = 1 + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + # Scale factor + scale = 1.0 / (head_size**0.5) + + # Same initial state for both [B, HV, V, K] - pretranspose layout + state_init = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + + # FlashInfer + state_fi = state_init.clone() + output_fi = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + gated_delta_rule_decode_pretranspose( + q, k, v, state_fi, A_log, a, dt_bias, b, scale, output_fi, use_qk_l2norm + ) + + # Triton + state_tr = state_init.clone() + output_tr = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + triton_gdn_decode_pretranspose( + q, k, v, state_tr, A_log, a, dt_bias, b, scale, output_tr, use_qk_l2norm + ) + + # Compare outputs using torch.testing.assert_close + try: + torch.testing.assert_close( + output_fi.float(), output_tr.float(), rtol=rtol, atol=atol + ) + output_close = True + except AssertionError as e: + output_close = False + print(f" Output mismatch: {e}") + + try: + torch.testing.assert_close( + state_fi.float(), state_tr.float(), rtol=rtol, atol=atol + ) + state_close = True + except AssertionError as e: + state_close = False + print(f" State mismatch: {e}") + + return output_close and state_close + + +# ============================================================================ +# All Layouts Comparison Benchmark +# ============================================================================ + + +def format_time(t): + """Format time value, returning 'N/A' if None.""" + return f"{t:>8.2f}" if t is not None else " N/A" + + +def format_speedup(base, other): + """Calculate and format speedup.""" + if base is None or other is None or base == 0: + return " N/A" + return f"{other / base:>7.2f}x" + + +def bench_all_layouts( + batch_size: int, + num_q_heads: int, + num_k_heads: int, + num_v_heads: int, + head_size: int, + dtype: torch.dtype, + use_qk_l2norm: bool = True, + warmup_iters: int = 10, + bench_iters: int = 100, +): + """Benchmark all 4 implementations: FlashInfer/Triton x pretranspose/nontranspose.""" + num_o_heads = max(num_q_heads, num_v_heads) + num_sab_heads = num_o_heads + + # Create inputs (T=1 for decode) + T = 1 + q = torch.randn(batch_size, T, num_q_heads, head_size, dtype=dtype, device="cuda") + k = torch.randn(batch_size, T, num_k_heads, head_size, dtype=dtype, device="cuda") + v = torch.randn(batch_size, T, num_v_heads, head_size, dtype=dtype, device="cuda") + + # GDN-specific parameters + A_log = torch.randn(num_sab_heads, dtype=torch.float32, device="cuda") + a = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + dt_bias = torch.randn(num_sab_heads, dtype=dtype, device="cuda") + b = torch.randn(batch_size, T, num_sab_heads, dtype=dtype, device="cuda") + + scale = 1.0 / (head_size**0.5) + + results = {"batch_size": batch_size} + + # ========== FlashInfer Pretranspose ========== + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + try: + times = bench_gpu_time( + lambda: gated_delta_rule_decode_pretranspose( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + results["fi_pretrans_us"] = np.median(times) * 1000 + except Exception as e: + results["fi_pretrans_us"] = None + print(f" FlashInfer pretranspose failed: {type(e).__name__}") + + # ========== FlashInfer Nontranspose ========== + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + try: + times = bench_gpu_time( + lambda: gated_delta_rule_decode( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + results["fi_nontrans_us"] = np.median(times) * 1000 + except Exception as e: + results["fi_nontrans_us"] = None + print(f" FlashInfer nontranspose failed: {type(e).__name__}") + + # ========== Triton Pretranspose ========== + if TRITON_AVAILABLE: + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + try: + times = bench_gpu_time( + lambda: triton_gdn_decode_pretranspose( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + results["tr_pretrans_us"] = np.median(times) * 1000 + except Exception as e: + results["tr_pretrans_us"] = None + print(f" Triton pretranspose failed: {type(e).__name__}") + + # ========== Triton Nontranspose ========== + state = torch.randn( + batch_size, + num_sab_heads, + head_size, + head_size, + dtype=torch.float32, + device="cuda", + ) + output = torch.empty( + batch_size, T, num_o_heads, head_size, dtype=dtype, device="cuda" + ) + + try: + times = bench_gpu_time( + lambda: triton_gdn_decode( + q, k, v, state, A_log, a, dt_bias, b, scale, output, use_qk_l2norm + ), + enable_cupti=True, + dry_run_iters=warmup_iters, + repeat_iters=bench_iters, + ) + results["tr_nontrans_us"] = np.median(times) * 1000 + except Exception as e: + results["tr_nontrans_us"] = None + print(f" Triton nontranspose failed: {type(e).__name__}") + else: + results["tr_pretrans_us"] = None + results["tr_nontrans_us"] = None + + return results + + +def run_all_layouts_benchmark(args, dtype, use_qk_l2norm): + """Run benchmark comparing all layouts: FlashInfer/Triton x pretranspose/nontranspose.""" + # Verify correctness first if requested + if args.verify and TRITON_AVAILABLE: + print("\n=== Correctness Verification ===") + for batch_size in [8, 16, 32, 64]: + print(f"Batch={batch_size}:") + # Pretranspose + try: + passed = verify_correctness_pretranspose( + batch_size=batch_size, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + ) + print(f" Pretranspose: {'PASS' if passed else 'FAIL'}") + except Exception as e: + print(f" Pretranspose: ERROR - {type(e).__name__}") + # Nontranspose + try: + passed = verify_correctness( + batch_size=batch_size, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + ) + print(f" Nontranspose: {'PASS' if passed else 'FAIL'}") + except Exception as e: + print(f" Nontranspose: ERROR - {type(e).__name__}") + print() + + print("\n" + "=" * 120) + print("GDN Decode Benchmark: FlashInfer vs Triton, Pretranspose vs Nontranspose") + print( + f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " + f"v_heads={args.num_v_heads}, head_size={args.head_size}, " + f"dtype={args.dtype}, qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}" + ) + print("=" * 120) + print() + print( + f"{'batch':>6} | {'FI-PreTr':>8} {'FI-NonTr':>8} | {'TR-PreTr':>8} {'TR-NonTr':>8} | " + f"{'FI/TR-Pre':>9} {'FI/TR-Non':>9} | {'Pre/Non-FI':>10} {'Pre/Non-TR':>10}" + ) + print( + f"{'':>6} | {'(us)':>8} {'(us)':>8} | {'(us)':>8} {'(us)':>8} | " + f"{'speedup':>9} {'speedup':>9} | {'speedup':>10} {'speedup':>10}" + ) + print("-" * 120) + + all_results = [] + for batch_size in args.batch_size: + result = bench_all_layouts( + batch_size=batch_size, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) + all_results.append(result) + + fi_pre = result.get("fi_pretrans_us") + fi_non = result.get("fi_nontrans_us") + tr_pre = result.get("tr_pretrans_us") + tr_non = result.get("tr_nontrans_us") + + # FI/TR speedup (>1 means FI faster) + fi_tr_pre = format_speedup(fi_pre, tr_pre) + fi_tr_non = format_speedup(fi_non, tr_non) + + # Pre/Non speedup (>1 means pretranspose faster) + pre_non_fi = format_speedup(fi_pre, fi_non) + pre_non_tr = format_speedup(tr_pre, tr_non) + + print( + f"{batch_size:>6} | {format_time(fi_pre)} {format_time(fi_non)} | " + f"{format_time(tr_pre)} {format_time(tr_non)} | " + f"{fi_tr_pre} {fi_tr_non} | {pre_non_fi} {pre_non_tr}" + ) + + print("-" * 120) + print() + print("Legend:") + print(" FI-PreTr = FlashInfer Pretranspose [B, HV, V, K]") + print(" FI-NonTr = FlashInfer Nontranspose [B, HV, K, V]") + print(" TR-PreTr = Triton Pretranspose [B, HV, V, K]") + print(" TR-NonTr = Triton Nontranspose [B, HV, K, V]") + print(" FI/TR speedup > 1.0 means FlashInfer is faster than Triton") + print(" Pre/Non speedup > 1.0 means Pretranspose is faster than Nontranspose") + print() + + # Summary statistics + fi_pre_times = [r["fi_pretrans_us"] for r in all_results if r.get("fi_pretrans_us")] + tr_pre_times = [r["tr_pretrans_us"] for r in all_results if r.get("tr_pretrans_us")] + + if fi_pre_times and tr_pre_times: + speedups = [tr / fi for fi, tr in zip(fi_pre_times, tr_pre_times, strict=False)] + print( + f"FlashInfer vs Triton (Pretranspose) - Average speedup: {np.mean(speedups):.2f}x" + ) + + +# ============================================================================ +# Main Entry Points +# ============================================================================ + + +def run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm): + """Run FlashInfer-only benchmarks.""" + # Determine which versions to benchmark + if args.version == "all": + versions_to_bench = ["pretranspose", "nontranspose", "mtp"] + else: + versions_to_bench = [args.version] + + for version in versions_to_bench: + if version == "mtp": + # Benchmark MTP version + print( + f"\nGDN MTP Benchmark " + f"(heads: q={args.num_q_heads}, k={args.num_k_heads}, " f"v={args.num_v_heads}, d={args.head_size}, dtype={args.dtype}, " f"qk_l2norm={'ON' if use_qk_l2norm else 'OFF'})" ) @@ -679,5 +2169,307 @@ def main(): print("-" * 90) +def run_comparison_benchmark(args, dtype, use_qk_l2norm): + """Run comparison benchmarks (FlashInfer vs Triton).""" + if not TRITON_AVAILABLE: + print("Error: Triton is not available. Install with: pip install triton") + return + + # Verify correctness first if requested + if args.verify: + version_name = args.version.upper() if args.version != "all" else "NONTRANSPOSE" + print(f"\n=== Correctness Verification ({version_name}) ===") + # Use larger batch sizes to avoid alignment issues with small batches + for batch_size in [8, 16, 32, 64]: + try: + if args.version == "pretranspose": + passed = verify_correctness_pretranspose( + batch_size=batch_size, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + ) + else: + passed = verify_correctness( + batch_size=batch_size, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + ) + status = "PASS" if passed else "FAIL" + print(f"Batch={batch_size}: {status}") + except Exception as e: + print(f"Batch={batch_size}: ERROR - {type(e).__name__}") + print() + + if args.version == "mtp": + # MTP comparison + print("\nGDN MTP Comparison: FlashInfer (CuTe DSL) vs Triton") + print( + f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " + f"v_heads={args.num_v_heads}, head_size={args.head_size}, dtype={args.dtype}, " + f"qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}, " + f"cache_intermediate={'ON' if args.cache_intermediate_states else 'OFF'}" + ) + print("-" * 110) + print( + f"{'batch':>6} {'seq_len':>8} {'FlashInfer(us)':>14} {'Triton(us)':>12} " + f"{'FI TFLOPS':>10} {'TR TFLOPS':>10} {'Speedup':>10}" + ) + print("-" * 110) + + results = [] + for batch_size in args.batch_size: + for seq_len in args.seq_len: + result = bench_mtp_comparison( + batch_size=batch_size, + seq_len=seq_len, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + cache_intermediate_states=args.cache_intermediate_states, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) + results.append(result) + + print( + f"{result['batch_size']:>6} {result['seq_len']:>8} " + f"{result['flashinfer_us']:>14.2f} {result['triton_us']:>12.2f} " + f"{result['flashinfer_tflops']:>10.2f} {result['triton_tflops']:>10.2f} " + f"{result['speedup']:>10.2f}x" + ) + + print("-" * 110) + elif args.version == "pretranspose": + # Pretranspose decode comparison + print("\nGDN Decode Comparison (PRETRANSPOSE): FlashInfer (CuTe DSL) vs Triton") + print( + f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " + f"v_heads={args.num_v_heads}, head_size={args.head_size}, dtype={args.dtype}, " + f"qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}" + ) + print("-" * 100) + print( + f"{'batch':>6} {'FlashInfer(us)':>14} {'Triton(us)':>12} " + f"{'FI TFLOPS':>10} {'TR TFLOPS':>10} {'Speedup':>10}" + ) + print("-" * 100) + + results = [] + for batch_size in args.batch_size: + result = bench_comparison_pretranspose( + batch_size=batch_size, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) + results.append(result) + + print( + f"{result['batch_size']:>6} {result['flashinfer_us']:>14.2f} " + f"{result['triton_us']:>12.2f} {result['flashinfer_tflops']:>10.2f} " + f"{result['triton_tflops']:>10.2f} {result['speedup']:>10.2f}x" + ) + + print("-" * 100) + else: + # Nontranspose decode comparison + print("\nGDN Decode Comparison (NONTRANSPOSE): FlashInfer (CuTe DSL) vs Triton") + print( + f"Config: q_heads={args.num_q_heads}, k_heads={args.num_k_heads}, " + f"v_heads={args.num_v_heads}, head_size={args.head_size}, dtype={args.dtype}, " + f"qk_l2norm={'ON' if use_qk_l2norm else 'OFF'}" + ) + print("-" * 100) + print( + f"{'batch':>6} {'FlashInfer(us)':>14} {'Triton(us)':>12} " + f"{'FI TFLOPS':>10} {'TR TFLOPS':>10} {'Speedup':>10}" + ) + print("-" * 100) + + results = [] + for batch_size in args.batch_size: + result = bench_comparison( + batch_size=batch_size, + num_q_heads=args.num_q_heads, + num_k_heads=args.num_k_heads, + num_v_heads=args.num_v_heads, + head_size=args.head_size, + dtype=dtype, + use_qk_l2norm=use_qk_l2norm, + warmup_iters=args.warmup, + bench_iters=args.iters, + ) + results.append(result) + + print( + f"{result['batch_size']:>6} {result['flashinfer_us']:>14.2f} " + f"{result['triton_us']:>12.2f} {result['flashinfer_tflops']:>10.2f} " + f"{result['triton_tflops']:>10.2f} {result['speedup']:>10.2f}x" + ) + + print("-" * 100) + + print("Speedup > 1.0 means FlashInfer is faster") + + # Print summary + speedups = [r["speedup"] for r in results] + min_idx = speedups.index(min(speedups)) + max_idx = speedups.index(max(speedups)) + print("\nSummary:") + print(f" Average speedup: {np.mean(speedups):.2f}x") + if args.version == "mtp": + print( + f" Min speedup: {speedups[min_idx]:.2f}x " + f"(batch={results[min_idx]['batch_size']}, T={results[min_idx]['seq_len']})" + ) + print( + f" Max speedup: {speedups[max_idx]:.2f}x " + f"(batch={results[max_idx]['batch_size']}, T={results[max_idx]['seq_len']})" + ) + else: + print( + f" Min speedup: {speedups[min_idx]:.2f}x (batch={results[min_idx]['batch_size']})" + ) + print( + f" Max speedup: {speedups[max_idx]:.2f}x (batch={results[max_idx]['batch_size']})" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="GDN Decode Benchmark", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Default: All layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + python benchmarks/bench_gdn_decode.py --batch-size 1 4 8 16 32 64 128 256 512 + + # Single layout comparison: FlashInfer vs Triton (nontranspose) + python benchmarks/bench_gdn_decode.py --compare --batch-size 1 4 8 16 32 64 128 256 512 + + # Single layout comparison: FlashInfer vs Triton (pretranspose) + python benchmarks/bench_gdn_decode.py --compare --version pretranspose --batch-size 1 4 8 16 32 64 128 256 512 + + # MTP benchmark (FlashInfer only) + python benchmarks/bench_gdn_decode.py --version mtp --batch-size 1 32 128 + + # MTP comparison: FlashInfer vs Triton + python benchmarks/bench_gdn_decode.py --version mtp --compare --batch-size 1 32 128 +""", + ) + parser.add_argument( + "--batch-size", + type=int, + nargs="+", + default=[1, 4, 8, 16, 32, 64, 128, 256, 512], + help="Batch sizes to benchmark (number of concurrent decode requests)", + ) + parser.add_argument("--num-q-heads", type=int, default=16) + parser.add_argument("--num-k-heads", type=int, default=16) + parser.add_argument("--num-v-heads", type=int, default=32) + parser.add_argument("--head-size", type=int, default=128) + parser.add_argument( + "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16" + ) + parser.add_argument( + "--preset", + type=str, + choices=["qwen3-next", "custom"], + default="custom", + help="Use preset config. qwen3-next: q=k=16, v=32, d=128", + ) + parser.add_argument( + "--no-qk-l2norm", + action="store_true", + help="Disable Q/K L2 normalization", + ) + parser.add_argument( + "--version", + type=str, + choices=["pretranspose", "nontranspose", "mtp", "all"], + default="nontranspose", + help="Kernel version: pretranspose (V-major state), nontranspose (K-major state), mtp (Multiple Token Processing), or all", + ) + parser.add_argument( + "--seq-len", + type=int, + nargs="+", + default=[2, 4, 8], + help="Sequence lengths for MTP benchmark (T > 1)", + ) + parser.add_argument( + "--cache-intermediate-states", + action="store_true", + help="Cache intermediate states for MTP benchmark", + ) + parser.add_argument( + "--warmup", + type=int, + default=10, + help="Number of warmup iterations", + ) + parser.add_argument( + "--iters", + type=int, + default=100, + help="Number of benchmark iterations", + ) + parser.add_argument( + "--compare", + action="store_true", + help="Run comparison benchmark: FlashInfer vs Triton", + ) + parser.add_argument( + "--verify", + action="store_true", + help="Run correctness verification before comparison benchmarking", + ) + args = parser.parse_args() + + # Apply preset configurations + if args.preset == "qwen3-next": + # Qwen3-Next-80B-A3B linear attention config (GVA) + args.num_q_heads = 16 + args.num_k_heads = 16 + args.num_v_heads = 32 + args.head_size = 128 + + # Check SM90 support + device_capability = torch.cuda.get_device_capability() + if device_capability[0] < 9: + print(f"Current device capability: {device_capability}") + print("GDN requires SM90 (Hopper) or later. Exiting...") + return + + dtype = getattr(torch, args.dtype) + use_qk_l2norm = not args.no_qk_l2norm + + if args.version == "mtp": + # MTP mode: use comparison or flashinfer-only + if args.compare: + run_comparison_benchmark(args, dtype, use_qk_l2norm) + else: + run_flashinfer_only_benchmark(args, dtype, use_qk_l2norm) + else: + # Non-MTP: always run all layouts comparison (FlashInfer/Triton x pretranspose/nontranspose) + run_all_layouts_benchmark(args, dtype, use_qk_l2norm) + + if __name__ == "__main__": main() diff --git a/flashinfer/gdn_decode.py b/flashinfer/gdn_decode.py index 476957a3fd..e64c231686 100644 --- a/flashinfer/gdn_decode.py +++ b/flashinfer/gdn_decode.py @@ -90,10 +90,43 @@ def flashinfer_api(func): # type: ignore[misc] # ============================================================================ # Global configuration for MTP (Multiple Token Processing) version # ============================================================================ -TILE_V_MTP = 8 -TILE_K_MTP = 128 -NUM_STAGES_MTP = 2 -NUM_THREADS_MTP = 128 # 4 warps +# Dynamic kernel selection based on batch size: +# - Small batch (B <= threshold): Use smaller TILE_V for more parallelism +# Optimal TILE_V depends on batch size (empirically determined): +# B=1-2: TILE_V=4 (more blocks, better parallelism for tiny batches) +# B=4: TILE_V=8 (intermediate parallelism) +# B=8: TILE_V=16 (balance between parallelism and efficiency) +# B>=16: TILE_V=32 (fewer blocks, better efficiency for large batches) + +TILE_K_MTP = 128 # Full K dimension (shared across all configs) +NUM_THREADS_MTP = 128 # 4 warps (shared across all configs) + + +def get_vec_size_mtp(batch_size: int, seq_len: int = 1) -> int: + """Select vec_size for MTP kernel. + + Always use vec_size=4 (32 threads per group = full warp, 4 groups per block). + Full warp shuffle is more efficient and achieves >= 1.0x speedup vs Triton. + """ + return 4 + + +def get_tile_v_mtp(batch_size: int, seq_len: int = 1) -> int: + """Select optimal TILE_V for MTP kernel based on batch size and sequence length. + + With vec_size=4, num_groups=4, rows_per_group = tile_v / 4. + Tuned via grid search for optimal performance. + """ + if batch_size <= 2: + return 4 # Small batch needs max parallelism + elif batch_size <= 4: + return 8 + elif batch_size <= 8: + return 16 + elif batch_size <= 16: + return 32 + else: + return 64 @cute.kernel @@ -169,6 +202,19 @@ def gdn_decode_kernel_small_batch_pretranspose( r_h = cute.make_rmem_tensor( cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) + # BF16 register tensors for vectorized q, k, v loading + r_q_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_k_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_v_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + + # Compute k_start for contiguous access pattern + k_start = lane_id * vec_size cute.arch.barrier() @@ -201,12 +247,22 @@ def gdn_decode_kernel_small_batch_pretranspose( cute.copy(tiled_copy_load, thr_gSrc, thr_sData) cute.arch.cp_async_commit_group() - for i in range(vec_size): - r_q[i] = cutlass.Float32(q[i_n, i_t, i_h, i * 32 + lane_id]) - r_k[i] = cutlass.Float32(k[i_n, i_t, i_h, i * 32 + lane_id]) - # Store v to shared memory instead of register - v_val = cutlass.Float32(v[i_n, i_t, i_hv, i * 32 + lane_id]) - sV[i * 32 + lane_id] = v_val + # Load q, k into BF16 registers using autovec_copy (contiguous pattern) + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + + # Convert BF16 to FP32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV + v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) + cute.autovec_copy(v_tile, r_v_bf16) + for i in cutlass.range_constexpr(vec_size): + sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) cute.arch.barrier() # Ensure all threads finish writing to sV @@ -223,9 +279,9 @@ def gdn_decode_kernel_small_batch_pretranspose( if beta_x <= softplus_threshold: # softplus(x) = (1/beta) * log(1 + exp(beta*x)) # Compute in Float32 - exp_beta_x = cute.exp(beta_x) + exp_beta_x = cute.exp(beta_x, fastmath=True) log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input)) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) softplus_x = cutlass.Float32( (cutlass.Float32(1.0) / softplus_beta) * log_result ) @@ -233,13 +289,13 @@ def gdn_decode_kernel_small_batch_pretranspose( softplus_x = x # Compute g = exp(A_log) * softplus_x - r_g_value = -cute.exp(r_A_log) * softplus_x + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x # Compute beta = 1 / (1 + exp(-b)) - r_beta = 1.0 / (1.0 + cute.exp(-r_b)) + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) # Store to scalar (Float32) - r_g = cute.exp(r_g_value) + r_g = cute.exp(r_g_value, fastmath=True) r_g = cute.arch.shuffle_sync(r_g, 0) r_beta = cute.arch.shuffle_sync(r_beta, 0) @@ -248,7 +304,7 @@ def gdn_decode_kernel_small_batch_pretranspose( # Compute L2 norm of q and k sum_q = 0.0 sum_k = 0.0 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): sum_q += r_q[i] * r_q[i] sum_k += r_k[i] * r_k[i] # Warp-level reduction using butterfly shuffle @@ -260,14 +316,14 @@ def gdn_decode_kernel_small_batch_pretranspose( sum_k, offset=offset, mask=-1, mask_and_clamp=31 ) - norm_q = cute.sqrt(sum_q + 1e-6) - norm_k = cute.sqrt(sum_k + 1e-6) - for i in range(vec_size): - r_q[i] = r_q[i] / norm_q - r_k[i] = r_k[i] / norm_k + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k # Apply scaling in Float32 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): r_q[i] = r_q[i] * scale # =================================================================== @@ -295,12 +351,18 @@ def gdn_decode_kernel_small_batch_pretranspose( cute.copy(tiled_copy_load, thr_gSrc, thr_sData) cute.arch.cp_async_commit_group() - # Step 3: Compute using data from current stage - for row in range(0, TILE_V, 4): + # Step 3: Compute using data from current stage (contiguous access pattern) + for row in cutlass.range_constexpr(0, TILE_V, 4): row_offset = tidx // 32 sum_hk = 0.0 - for i in range(vec_size): - r_h[i] = sData[(row + row_offset, i * 32 + lane_id, stage)] + + # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) + sData_tile = cute.local_tile( + sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) + ) + cute.autovec_copy(sData_tile, r_h) + + for i in cutlass.range_constexpr(vec_size): r_h[i] = r_h[i] * r_g sum_hk += r_h[i] * r_k[i] @@ -313,11 +375,16 @@ def gdn_decode_kernel_small_batch_pretranspose( v_new = v_new * r_beta sum_hq = 0.0 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): r_h[i] += r_k[i] * v_new - gDst[(0, row + row_offset, i * 32 + lane_id, v_tiles)] = r_h[i] sum_hq += r_h[i] * r_q[i] + # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) + gDst_tile = cute.local_tile( + gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) + ) + cute.autovec_copy(r_h, gDst_tile) + for offset in [16, 8, 4, 2, 1]: sum_hq += cute.arch.shuffle_sync_bfly( sum_hq, offset=offset, mask=-1, mask_and_clamp=31 @@ -406,6 +473,19 @@ def gdn_decode_kernel_big_batch_pretranspose( r_h = cute.make_rmem_tensor( cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) + # BF16 register tensors for vectorized q, k, v loading + r_q_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_k_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_v_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + + # Compute k_start for contiguous access pattern + k_start = lane_id * vec_size cute.arch.barrier() @@ -437,12 +517,22 @@ def gdn_decode_kernel_big_batch_pretranspose( cute.copy(tiled_copy_load, thr_gSrc, thr_sData) cute.arch.cp_async_commit_group() - for i in range(vec_size): - r_q[i] = cutlass.Float32(q[i_n, i_t, i_h, i * 32 + lane_id]) - r_k[i] = cutlass.Float32(k[i_n, i_t, i_h, i * 32 + lane_id]) - # Store v to shared memory instead of register - v_val = cutlass.Float32(v[i_n, i_t, i_hv, i * 32 + lane_id]) - sV[i * 32 + lane_id] = v_val + # Load q, k into BF16 registers using autovec_copy (contiguous pattern) + q_tile = cute.local_tile(q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + k_tile = cute.local_tile(k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_id)) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) + + # Convert BF16 to FP32 + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) + + # Load v into BF16 registers using autovec_copy, convert to FP32, store to sV + v_tile = cute.local_tile(v, (1, 1, 1, vec_size), (i_n, i_t, i_hv, lane_id)) + cute.autovec_copy(v_tile, r_v_bf16) + for i in cutlass.range_constexpr(vec_size): + sV[k_start + i] = cutlass.Float32(r_v_bf16[i]) cute.arch.barrier() # Ensure all threads finish writing to sV @@ -459,9 +549,9 @@ def gdn_decode_kernel_big_batch_pretranspose( if beta_x <= softplus_threshold: # softplus(x) = (1/beta) * log(1 + exp(beta*x)) # Compute in Float32 - exp_beta_x = cute.exp(beta_x) + exp_beta_x = cute.exp(beta_x, fastmath=True) log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input)) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) softplus_x = cutlass.Float32( (cutlass.Float32(1.0) / softplus_beta) * log_result ) @@ -469,13 +559,13 @@ def gdn_decode_kernel_big_batch_pretranspose( softplus_x = x # Compute g = exp(A_log) * softplus_x - r_g_value = -cute.exp(r_A_log) * softplus_x + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x # Compute beta = 1 / (1 + exp(-b)) - r_beta = 1.0 / (1.0 + cute.exp(-r_b)) + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) # Store to scalar (Float32) - r_g = cute.exp(r_g_value) + r_g = cute.exp(r_g_value, fastmath=True) r_g = cute.arch.shuffle_sync(r_g, 0) r_beta = cute.arch.shuffle_sync(r_beta, 0) @@ -484,7 +574,7 @@ def gdn_decode_kernel_big_batch_pretranspose( # Compute L2 norm of q and k sum_q = 0.0 sum_k = 0.0 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): sum_q += r_q[i] * r_q[i] sum_k += r_k[i] * r_k[i] # Warp-level reduction using butterfly shuffle @@ -496,14 +586,14 @@ def gdn_decode_kernel_big_batch_pretranspose( sum_k, offset=offset, mask=-1, mask_and_clamp=31 ) - norm_q = cute.sqrt(sum_q + 1e-6) - norm_k = cute.sqrt(sum_k + 1e-6) - for i in range(vec_size): - r_q[i] = r_q[i] / norm_q - r_k[i] = r_k[i] / norm_k + inv_norm_q = cute.rsqrt(sum_q + 1e-6, fastmath=True) + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q + r_k[i] = r_k[i] * inv_norm_k # Apply scaling in Float32 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): r_q[i] = r_q[i] * scale # =================================================================== @@ -530,12 +620,18 @@ def gdn_decode_kernel_big_batch_pretranspose( cute.copy(tiled_copy_load, thr_gSrc, thr_sData) cute.arch.cp_async_commit_group() - # Step 3: Compute using data from current stage - for row in range(0, TILE_V, 4): + # Step 3: Compute using data from current stage (contiguous access pattern) + for row in cutlass.range_constexpr(0, TILE_V, 4): row_offset = tidx // 32 sum_hk = 0.0 - for i in range(vec_size): - r_h[i] = sData[(row + row_offset, i * 32 + lane_id, stage)] + + # Load h from sData using 3D local_tile + autovec_copy (contiguous in K) + sData_tile = cute.local_tile( + sData, (1, vec_size, 1), (row + row_offset, lane_id, stage) + ) + cute.autovec_copy(sData_tile, r_h) + + for i in cutlass.range_constexpr(vec_size): r_h[i] = r_h[i] * r_g sum_hk += r_h[i] * r_k[i] @@ -548,11 +644,16 @@ def gdn_decode_kernel_big_batch_pretranspose( v_new = v_new * r_beta sum_hq = 0.0 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): r_h[i] += r_k[i] * v_new - gDst[(0, row + row_offset, i * 32 + lane_id, v_tiles)] = r_h[i] sum_hq += r_h[i] * r_q[i] + # Write h to gDst using 4D local_tile + autovec_copy (contiguous in K) + gDst_tile = cute.local_tile( + gDst, (1, 1, vec_size, 1), (0, row + row_offset, lane_id, v_tiles) + ) + cute.autovec_copy(r_h, gDst_tile) + for offset in [16, 8, 4, 2, 1]: sum_hq += cute.arch.shuffle_sync_bfly( sum_hq, offset=offset, mask=-1, mask_and_clamp=31 @@ -599,6 +700,7 @@ def run_gdn_decode_kernel_small_batch_pretranspose( is_varlen: cutlass.Constexpr[bool], stream: cuda.CUstream, ): + """Launch original pipelined kernel for small batch pretranspose.""" # h0_source: (B*HV, V, K) batch_size, v_dim, k_dim = ( h0_source.layout.shape[0], @@ -629,12 +731,6 @@ def run_gdn_decode_kernel_small_batch_pretranspose( TILE_K // 32 ) # Each thread in a warp processes this many elements (always 4 for TILE_K=128) - # print(f"Batched CP.ASYNC Load + Store (bypass L1 cache)") - # print(f" {batch_size} batches x {v_dim}x{k_dim} matrices") - # print(f" Tile: {TILE_V}x{TILE_K}, {num_v_tiles} tiles/batch") - # print(f" Threads: {NUM_THREADS} ({NUM_THREADS // 32} warps), vec_size: {vec_size}") - # print(f" Total: {total_data_mb:.1f} MB\n") - # Create SMEM layout smem_layout_staged = cute.make_layout( (TILE_V, TILE_K, NUM_STAGES), stride=(TILE_K, 1, TILE_V * TILE_K) @@ -925,23 +1021,20 @@ def gated_delta_rule_decode_pretranspose( # Convert state from [B, HV, V, K] to [B*HV, V, K] for kernel h0_source = state.reshape(B * HV, V, K) - # Create dummy tensors for unused parameters - h0_indices = torch.zeros(B, dtype=torch.int32, device=q.device) - cu_seqlens = torch.zeros(B + 1, dtype=torch.int32, device=q.device) - # Compile kernel with TVM FFI (cached) cache_key = (B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm) cache = _get_compiled_decode_kernel(*cache_key) + # Get or create h0_indices and cu_seqlens (cached per config) + if "h0_indices" not in cache or cache["h0_indices"].device != q.device: + cache["h0_indices"] = torch.zeros(B, dtype=torch.int32, device=q.device) + cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) + h0_indices = cache["h0_indices"] + cu_seqlens = cache["cu_seqlens"] + if "compiled" not in cache: stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # Choose kernel based on batch size - if B <= 32: - run_func = run_gdn_decode_kernel_small_batch_pretranspose - else: - run_func = run_gdn_decode_kernel_big_batch_pretranspose - # Convert tensors to CuTe format for compilation only h0_source_tensor = from_dlpack(h0_source, assumed_align=16) A_log_tensor = from_dlpack(A_log, assumed_align=16) @@ -955,6 +1048,12 @@ def gated_delta_rule_decode_pretranspose( h0_indices_tensor = from_dlpack(h0_indices, assumed_align=16) cu_seqlens_tensor = from_dlpack(cu_seqlens, assumed_align=16) + # Choose kernel based on batch size + if B <= 32: + run_func = run_gdn_decode_kernel_small_batch_pretranspose + else: + run_func = run_gdn_decode_kernel_big_batch_pretranspose + # Use TVM FFI to reduce runtime overhead compiled = cute.compile( run_func, @@ -994,8 +1093,10 @@ def gated_delta_rule_decode_pretranspose( h0_source, A_log, a, dt_bias, q, k, v, b, output, h0_indices, cu_seqlens, stream ) - # Reshape state back (no sync needed - PyTorch handles stream ordering) - state.copy_(h0_source.reshape(B, HV, V, K)) + # Copy state back only if state was not contiguous + # (if contiguous, reshape returns a view and kernel updated state in-place) + if not state.is_contiguous(): + state.copy_(h0_source.reshape(B, HV, V, K)) # Convert output to target dtype if needed (kernel outputs bfloat16) if output.dtype != target_dtype: @@ -1073,7 +1174,9 @@ def gdn_decode_kernel_small_batch_nontranspose( sK[tidx] = cutlass.Float32(k[i_n, 0, i_h, tidx]) sQ[tidx] = cutlass.Float32(q[i_n, 0, i_h, tidx]) - gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] + # Compute flat index for flattened state [B*HV, K, V] + flat_idx = pool_idx * HV + i_hv + gSrc_batch = h0_source[(flat_idx, None, None)] gSrc = cute.local_tile(gSrc_batch, (TILE_K_NT, TILE_V_SMALL_NT), (0, None)) thr_copy_load = tiled_copy_load.get_slice(tidx) @@ -1100,17 +1203,17 @@ def gdn_decode_kernel_small_batch_nontranspose( beta_x = softplus_beta * x softplus_x = 0.0 if beta_x <= softplus_threshold: - exp_beta_x = cute.exp(beta_x) + exp_beta_x = cute.exp(beta_x, fastmath=True) log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input)) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) softplus_x = cutlass.Float32( (cutlass.Float32(1.0) / softplus_beta) * log_result ) else: softplus_x = x - r_g_value = -cute.exp(r_A_log) * softplus_x - r_beta = 1.0 / (1.0 + cute.exp(-r_b)) - r_g = cute.exp(r_g_value) + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) + r_g = cute.exp(r_g_value, fastmath=True) r_g = cute.arch.shuffle_sync(r_g, 0) r_beta = cute.arch.shuffle_sync(r_beta, 0) @@ -1155,8 +1258,8 @@ def gdn_decode_kernel_small_batch_nontranspose( local_sum_k, offset=offset, mask=-1, mask_and_clamp=31 ) if in_warp_tid == 0: - smem_o[0] = cute.rsqrt(local_sum_q + 1e-6) - smem_o[1] = cute.rsqrt(local_sum_k + 1e-6) + smem_o[0] = cute.rsqrt(local_sum_q + 1e-6, fastmath=True) + smem_o[1] = cute.rsqrt(local_sum_k + 1e-6, fastmath=True) cute.arch.barrier() inv_norm_q = smem_o[0] @@ -1236,14 +1339,15 @@ def gdn_decode_kernel_small_batch_nontranspose( cute.arch.barrier() - for k_iter in range(NUM_K_ITERS_SMALL): - flat_idx = tidx + k_iter * 128 - k_write = flat_idx // TILE_V_SMALL_NT - v_write = flat_idx % TILE_V_SMALL_NT + for k_iter in cutlass.range_constexpr(NUM_K_ITERS_SMALL): + flat_tid = tidx + k_iter * 128 + k_write = flat_tid // TILE_V_SMALL_NT + v_write = flat_tid % TILE_V_SMALL_NT if k_write < TILE_K_NT: h_val = sData[(k_write, v_write, stage)] v_global_write = v_tile * TILE_V_SMALL_NT + v_write - h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val + # Use flat index for flattened state [B*HV, K, V] + h0_source[(flat_idx, k_write, v_global_write)] = h_val cute.arch.barrier() @@ -1301,7 +1405,9 @@ def gdn_decode_kernel_big_batch_nontranspose( sK[tidx] = cutlass.Float32(k[i_n, 0, i_h, tidx]) sQ[tidx] = cutlass.Float32(q[i_n, 0, i_h, tidx]) - gSrc_batch = h0_source[(pool_idx, i_hv, None, None)] + # Compute flat index for flattened state [B*HV, K, V] + flat_idx = pool_idx * HV + i_hv + gSrc_batch = h0_source[(flat_idx, None, None)] gSrc = cute.local_tile(gSrc_batch, (TILE_K_NT, TILE_V_NT), (0, None)) thr_copy_load = tiled_copy_load.get_slice(tidx) @@ -1327,17 +1433,17 @@ def gdn_decode_kernel_big_batch_nontranspose( beta_x = softplus_beta * x softplus_x = 0.0 if beta_x <= softplus_threshold: - exp_beta_x = cute.exp(beta_x) + exp_beta_x = cute.exp(beta_x, fastmath=True) log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input)) + log_result = cutlass.Float32(cute.log(log_input, fastmath=True)) softplus_x = cutlass.Float32( (cutlass.Float32(1.0) / softplus_beta) * log_result ) else: softplus_x = x - r_g_value = -cute.exp(r_A_log) * softplus_x - r_beta = 1.0 / (1.0 + cute.exp(-r_b)) - r_g = cute.exp(r_g_value) + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + r_beta = 1.0 / (1.0 + cute.exp(-r_b, fastmath=True)) + r_g = cute.exp(r_g_value, fastmath=True) r_g = cute.arch.shuffle_sync(r_g, 0) r_beta = cute.arch.shuffle_sync(r_beta, 0) @@ -1382,8 +1488,8 @@ def gdn_decode_kernel_big_batch_nontranspose( local_sum_k, offset=offset, mask=-1, mask_and_clamp=31 ) if in_warp_tid == 0: - smem_o[0] = cute.rsqrt(local_sum_q + 1e-6) - smem_o[1] = cute.rsqrt(local_sum_k + 1e-6) + smem_o[0] = cute.rsqrt(local_sum_q + 1e-6, fastmath=True) + smem_o[1] = cute.rsqrt(local_sum_k + 1e-6, fastmath=True) cute.arch.barrier() inv_norm_q = smem_o[0] @@ -1455,14 +1561,15 @@ def gdn_decode_kernel_big_batch_nontranspose( cute.arch.barrier() - for k_iter in range(NUM_K_ITERS_NT): - flat_idx = tidx + k_iter * 256 - k_write = flat_idx // TILE_V_NT - v_write = flat_idx % TILE_V_NT + for k_iter in cutlass.range_constexpr(NUM_K_ITERS_NT): + flat_tid = tidx + k_iter * 256 + k_write = flat_tid // TILE_V_NT + v_write = flat_tid % TILE_V_NT if k_write < TILE_K_NT: h_val = sData[(k_write, v_write, stage)] v_global_write = v_tile * TILE_V_NT + v_write - h0_source[(pool_idx, i_hv, k_write, v_global_write)] = h_val + # Use flat index for flattened state [B*HV, K, V] + h0_source[(flat_idx, k_write, v_global_write)] = h_val cute.arch.barrier() @@ -1493,9 +1600,10 @@ def run_gdn_decode_kernel_small_batch_nontranspose( use_qk_l2norm: cutlass.Constexpr[bool], stream: cuda.CUstream, ): - pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape - n_indices = h0_indices.layout.shape[0] - batch_size = n_indices * hv_dim + # h0_source is flattened to [B*HV, K, V] to ensure proper alignment for SIMT async copy + batch_hv_dim, k_dim, v_dim = h0_source.layout.shape + h0_indices.layout.shape[0] + batch_size = batch_hv_dim # batch_hv_dim = B * HV copy_atom = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), @@ -1573,9 +1681,10 @@ def run_gdn_decode_kernel_big_batch_nontranspose( use_qk_l2norm: cutlass.Constexpr[bool], stream: cuda.CUstream, ): - pool_size, hv_dim, k_dim, v_dim = h0_source.layout.shape - n_indices = h0_indices.layout.shape[0] - batch_size = n_indices * hv_dim + # h0_source is flattened to [B*HV, K, V] to ensure proper alignment for SIMT async copy + batch_hv_dim, k_dim, v_dim = h0_source.layout.shape + h0_indices.layout.shape[0] + batch_size = batch_hv_dim # batch_hv_dim = B * HV copy_atom = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), @@ -1727,18 +1836,23 @@ def gated_delta_rule_decode( # Kernel outputs bfloat16, allocate in that dtype first output = torch.zeros((B, T, HV, V), dtype=torch.bfloat16, device=q.device) - # State is already in K-major layout [B, HV, K, V] - # Use state directly as pooled view: pool_size=B, each batch is its own pool - h0_source = state.contiguous() - - # Create h0_indices: each batch points to its own pool index - h0_indices = torch.arange(B, dtype=torch.int32, device=q.device) - cu_seqlens = torch.zeros(B + 1, dtype=torch.int32, device=q.device) + # State is in K-major layout [B, HV, K, V] + # Flatten to [B*HV, K, V] to ensure proper alignment for SIMT async copy + # This avoids alignment issues when B=1 (zero strides cause alignment failures) + state_contiguous = state.contiguous() + h0_source = state_contiguous.view(B * HV, K, V) # Compile kernel with TVM FFI (cached) cache_key = (B, T, H, HV, K, V, q.dtype, scale, use_qk_l2norm) cache = _get_compiled_decode_kernel_nontranspose(*cache_key) + # Get or create h0_indices and cu_seqlens (cached per config) + if "h0_indices" not in cache or cache["h0_indices"].device != q.device: + cache["h0_indices"] = torch.arange(B, dtype=torch.int32, device=q.device) + cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) + h0_indices = cache["h0_indices"] + cu_seqlens = cache["cu_seqlens"] + if "compiled" not in cache: stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -1812,9 +1926,10 @@ def gated_delta_rule_decode( stream, ) - # Copy state back (no sync needed - PyTorch handles stream ordering) - # h0_source is already [B, HV, K, V] - state.copy_(h0_source) + # Copy state back only if state was not contiguous + # (if contiguous, state_contiguous is state itself, so kernel updated state in-place) + if state_contiguous.data_ptr() != state.data_ptr(): + state.copy_(state_contiguous) # Convert output to target dtype if needed (kernel outputs bfloat16) if output.dtype != target_dtype: @@ -1830,12 +1945,11 @@ def gated_delta_rule_decode( @cute.kernel def gdn_verify_kernel_mtp( - tiled_copy_load: cute.TiledCopy, h0_source: cute.Tensor, # [pool_size * HV, V, K] - initial state pool (K-last) intermediate_states: cute.Tensor, # [pool_size * T * HV, V, K] - intermediate state cache - smem_layout_staged: cute.ComposedLayout, # Swizzled layout to avoid bank conflicts vec_size: cutlass.Constexpr[int], num_v_tiles: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], # TILE_V - configurable for batch size A_log: cute.Tensor, # [HV] a: cute.Tensor, # [B, T, HV] dt_bias: cute.Tensor, # [HV] @@ -1862,44 +1976,61 @@ def gdn_verify_kernel_mtp( cache_intermediate_states: cutlass.Constexpr[bool], ): """ - Verify kernel with optimized loop order: (v_tiles outer, time_steps inner) + Parallel MTP kernel - each block handles one [TILE_V, TILE_K] tile. - Uses cute.make_rmem_tensor for register arrays (same style as decode kernel). - Store-compute overlap: cache data in registers, store during next iteration. - """ + Grid: (B * HV * num_v_tiles, 1, 1) + Each block: + - Loads its v_tile of state into registers + - Processes all T time steps with state in registers + - Writes output and optionally updates state + This matches Triton's parallelization strategy for better small-batch performance. + """ tidx, _, _ = cute.arch.thread_idx() lane_id = tidx % 32 warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Compute thread grouping based on vec_size: + # vec_size=8: 16 threads per group (half-warp), 8 groups per block + # vec_size=4: 32 threads per group (full warp), 4 groups per block + threads_per_group: cutlass.Constexpr[int] = K // vec_size # 16 or 32 + groups_per_warp: cutlass.Constexpr[int] = 32 // threads_per_group # 2 or 1 + num_groups: cutlass.Constexpr[int] = 4 * groups_per_warp # 8 or 4 + + # Lane position within group and group index + lane_in_group = lane_id % threads_per_group + group_in_warp = lane_id // threads_per_group + group_idx = warp_idx * groups_per_warp + group_in_warp + batch_idx, _, _ = cute.arch.block_idx() - i_n = batch_idx // HV - i_hv = batch_idx % HV + + # Decode block index: (i_n, i_hv, i_v) from batch_idx + i_v = batch_idx % num_v_tiles + tmp = batch_idx // num_v_tiles + i_hv = tmp % HV + i_n = tmp // HV i_h = i_hv // (HV // H) + # Get initial state index for this batch + cache_idx = h0_indices[i_n] + # Load A_log and dt_bias once (they don't vary with time) r_A_log = cutlass.Float32(A_log[i_hv]) r_dt_bias = cutlass.Float32(dt_bias[i_hv]) + # Allocate shared memory for pre-computed values (broadcast to all warps) smem = cutlass.utils.SmemAllocator() - - # Allocate shared memory with padding to avoid bank conflicts - sData = smem.allocate_tensor(cutlass.Float32, smem_layout_staged, 128) - sOutput = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T, V)), 16) - # Pre-computed shared memory for all time steps (with padding for bank conflict avoidance) sQ = smem.allocate_tensor( cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16 ) sK = smem.allocate_tensor( cutlass.Float32, cute.make_layout((T, K), stride=(K + 8, 1)), 16 ) - sV = smem.allocate_tensor( - cutlass.Float32, cute.make_layout((T, V), stride=(V + 8, 1)), 16 - ) sG = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T,)), 16) sBeta = smem.allocate_tensor(cutlass.Float32, cute.make_layout((T,)), 16) - # Allocate register tensors + # Register arrays for computation r_q = cute.make_rmem_tensor( cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) @@ -1909,64 +2040,45 @@ def gdn_verify_kernel_mtp( r_h = cute.make_rmem_tensor( cute.make_layout((vec_size,), stride=(1,)), cutlass.Float32 ) + # BF16 register tensors for vectorized q, k loading + r_q_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) + r_k_bf16 = cute.make_rmem_tensor( + cute.make_layout((vec_size,), stride=(1,)), cutlass.BFloat16 + ) - # Initialize output accumulator to zero - for i_t in range(T): - sOutput[(i_t, tidx)] = 0.0 - - cute.arch.barrier() - - # Get initial state index for this batch - cache_idx = h0_indices[i_n] - - # Early exit optimization: skip pre-computation for padding slots + # Only process valid batch entries (cache_idx >= 0) if cache_idx >= 0: - # Pre-compute q, k, v, g, beta for all time steps (outside v_tiles loop) - for i_t in range(T): - # Load q, k into register arrays - for i in range(vec_size): - r_q[i] = cutlass.Float32(q[i_n, i_t, i_h, i * 32 + lane_id]) - r_k[i] = cutlass.Float32(k[i_n, i_t, i_h, i * 32 + lane_id]) - # Load v for all V elements - sV[(i_t, i * 32 + lane_id)] = cutlass.Float32( - v[i_n, i_t, i_hv, i * 32 + lane_id] - ) - - # Compute g and beta - r_a = cutlass.Float32(a[i_n, i_t, i_hv]) - r_b = cutlass.Float32(b[i_n, i_t, i_hv]) - r_g = 0.0 - r_beta = 0.0 - if lane_id == 0: - x = r_a + r_dt_bias - beta_x = softplus_beta * x - softplus_x = 0.0 - - if beta_x <= softplus_threshold: - exp_beta_x = cute.exp(beta_x) - log_input = cutlass.Float32(1.0 + exp_beta_x) - log_result = cutlass.Float32(cute.log(log_input)) - softplus_x = cutlass.Float32( - (cutlass.Float32(1.0) / softplus_beta) * log_result - ) - else: - softplus_x = x - - r_g_value = -cute.exp(r_A_log) * softplus_x - r_beta = 1.0 / (1.0 + cute.exp(-r_b)) - r_g = cute.exp(r_g_value) + # Compute k_start once (used for shared memory writes) + k_start = lane_in_group * vec_size + + # Pre-compute q, k, g, beta for ALL time steps ONCE (shared across warps) + for i_t in cutlass.range_constexpr(T): + # Load q, k into BF16 registers using autovec_copy (coalesced) + q_tile = cute.local_tile( + q, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group) + ) + k_tile = cute.local_tile( + k, (1, 1, 1, vec_size), (i_n, i_t, i_h, lane_in_group) + ) + cute.autovec_copy(q_tile, r_q_bf16) + cute.autovec_copy(k_tile, r_k_bf16) - r_g = cute.arch.shuffle_sync(r_g, 0) - r_beta = cute.arch.shuffle_sync(r_beta, 0) + # Convert BF16 to FP32 for computation + for i in cutlass.range_constexpr(vec_size): + r_q[i] = cutlass.Float32(r_q_bf16[i]) + r_k[i] = cutlass.Float32(r_k_bf16[i]) - # Apply L2 normalization - if use_qk_l2norm: + # Apply L2 normalization to q, k (with scale fused for q) + if cutlass.const_expr(use_qk_l2norm): sum_q = 0.0 sum_k = 0.0 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): sum_q += r_q[i] * r_q[i] sum_k += r_k[i] * r_k[i] + # Warp-level reduction (32 threads per group with vec_size=4) for offset in [16, 8, 4, 2, 1]: sum_q += cute.arch.shuffle_sync_bfly( sum_q, offset=offset, mask=-1, mask_and_clamp=31 @@ -1975,150 +2087,138 @@ def gdn_verify_kernel_mtp( sum_k, offset=offset, mask=-1, mask_and_clamp=31 ) - inv_norm_q = cute.rsqrt(sum_q + 1e-6) - inv_norm_k = cute.rsqrt(sum_k + 1e-6) + # Fuse scale into q's normalization factor + inv_norm_q_scaled = cute.rsqrt(sum_q + 1e-6, fastmath=True) * scale + inv_norm_k = cute.rsqrt(sum_k + 1e-6, fastmath=True) - for i in range(vec_size): - r_q[i] = r_q[i] * inv_norm_q + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * inv_norm_q_scaled r_k[i] = r_k[i] * inv_norm_k + else: + # No L2 norm, just apply scale to q + for i in cutlass.range_constexpr(vec_size): + r_q[i] = r_q[i] * scale - # Apply scaling to q - for i in range(vec_size): - r_q[i] = r_q[i] * scale - - # Store pre-computed values to shared memory - for i in range(vec_size): - sQ[(i_t, i * 32 + lane_id)] = r_q[i] - sK[(i_t, i * 32 + lane_id)] = r_k[i] - - # Store g and beta (only one thread needs to do this) - if tidx == 0: - sG[i_t] = r_g - sBeta[i_t] = r_beta - - # All threads must participate in barrier (CUDA requirement) - cute.arch.barrier() + # Store to shared memory (only first group writes) - contiguous layout + if tidx < threads_per_group: + for i in cutlass.range_constexpr(vec_size): + sQ[(i_t, k_start + i)] = r_q[i] + sK[(i_t, k_start + i)] = r_k[i] - # Main computation only for valid batch entries - if cache_idx >= 0: - # Setup source tensor for initial state loading - gSrc_batch = h0_source[(cache_idx * HV + i_hv, None, None)] - gDst_h0 = cute.local_tile( - h0_source, (1, TILE_V_MTP, TILE_K_MTP), (cache_idx * HV + i_hv, None, 0) - ) - gSrc = cute.local_tile(gSrc_batch, (TILE_V_MTP, TILE_K_MTP), (None, 0)) + # Compute g, beta - all lanes compute (redundant but no divergence) + r_a = cutlass.Float32(a[i_n, i_t, i_hv]) + r_b = cutlass.Float32(b[i_n, i_t, i_hv]) - thr_copy_load = tiled_copy_load.get_slice(tidx) + x = r_a + r_dt_bias + beta_x = softplus_beta * x - # Main loop: v_tiles (outer) x time_steps (inner) - prefetch_count = cutlass.min(NUM_STAGES_MTP - 1, num_v_tiles) + # Branchless softplus + exp_beta_x = cute.exp(beta_x, fastmath=True) + softplus_val = (cutlass.Float32(1.0) / softplus_beta) * cute.log( + cutlass.Float32(1.0) + exp_beta_x, fastmath=True + ) + use_softplus = ( + cutlass.Float32(1.0) + if beta_x <= softplus_threshold + else cutlass.Float32(0.0) + ) + softplus_x = ( + use_softplus * softplus_val + (cutlass.Float32(1.0) - use_softplus) * x + ) - # Prefetch first v_tile(s) - for v_tiles in range(prefetch_count): - stage = v_tiles % NUM_STAGES_MTP - gSrc_tile = gSrc[(None, None, v_tiles)] - sData_stage = sData[(None, None, stage)] - thr_gSrc = thr_copy_load.partition_S(gSrc_tile) - thr_sData = thr_copy_load.partition_D(sData_stage) - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + r_g_value = -cute.exp(r_A_log, fastmath=True) * softplus_x + r_beta = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + cute.exp(-r_b, fastmath=True) + ) + r_g = cute.exp(r_g_value, fastmath=True) - # Process each v_tile - for v_tiles in range(num_v_tiles): - stage = v_tiles % NUM_STAGES_MTP + # Only thread 0 stores to shared memory + if tidx == 0: + sG[i_t] = r_g + sBeta[i_t] = r_beta - cute.arch.cp_async_wait_group(0) - cute.arch.barrier() + cute.arch.barrier() - # Prefetch next v_tile - next_v_tiles = v_tiles + prefetch_count - if next_v_tiles < num_v_tiles: - next_stage = next_v_tiles % NUM_STAGES_MTP - gSrc_next = gSrc[(None, None, next_v_tiles)] - sData_next = sData[(None, None, next_stage)] - thr_gSrc = thr_copy_load.partition_S(gSrc_next) - thr_sData = thr_copy_load.partition_D(sData_next) - cute.copy(tiled_copy_load, thr_gSrc, thr_sData) - cute.arch.cp_async_commit_group() + # Each group handles tile_v/num_groups V rows + rows_per_group: cutlass.Constexpr[int] = tile_v // num_groups + for row_in_group in cutlass.range_constexpr(rows_per_group): + v_idx = i_v * tile_v + group_idx * rows_per_group + row_in_group - # Inner loop: all time steps for this v_tile - for i_t in range(T): - # Load pre-computed values from shared memory - for i in range(vec_size): - r_q[i] = sQ[(i_t, i * 32 + lane_id)] - r_k[i] = sK[(i_t, i * 32 + lane_id)] + if v_idx < V: + # Load h[v_idx, :] into registers using 3D local_tile + autovec_copy + flat_state_idx = cache_idx * HV + i_hv + h_tile = cute.local_tile( + h0_source, (1, 1, vec_size), (flat_state_idx, v_idx, lane_in_group) + ) + cute.autovec_copy(h_tile, r_h) - r_g = sG[i_t] - r_beta = sBeta[i_t] + # Process all T time steps with h in registers + for i_t in cutlass.range_constexpr(T): + # Load pre-computed q, k from shared memory using 2D local_tile + sQ_tile = cute.local_tile(sQ, (1, vec_size), (i_t, lane_in_group)) + sK_tile = cute.local_tile(sK, (1, vec_size), (i_t, lane_in_group)) + cute.autovec_copy(sQ_tile, r_q) + cute.autovec_copy(sK_tile, r_k) - # Compute delta rule for this v_tile - for row in range(0, TILE_V_MTP, 4): - row_offset = tidx // 32 + r_g = sG[i_t] + r_beta = sBeta[i_t] - # Load h from sData, apply decay - for i in range(vec_size): - r_h[i] = ( - sData[(row + row_offset, i * 32 + lane_id, stage)] * r_g - ) + # Step 1: Apply decay to h + for i in cutlass.range_constexpr(vec_size): + r_h[i] = r_h[i] * r_g - # Compute sum_hk = h @ k + # Step 2: Compute sum_hk = h @ k (group reduction) sum_hk = 0.0 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): sum_hk += r_h[i] * r_k[i] + # Warp-level reduction for offset in [16, 8, 4, 2, 1]: sum_hk += cute.arch.shuffle_sync_bfly( sum_hk, offset=offset, mask=-1, mask_and_clamp=31 ) - # Delta rule update - v_idx = v_tiles * TILE_V_MTP + row + row_offset - v_new = sV[(i_t, v_idx)] - sum_hk - v_new = v_new * r_beta + # Step 3: Load v for this v_idx and time step, apply delta rule + r_v = cutlass.Float32(v[i_n, i_t, i_hv, v_idx]) + v_new = (r_v - sum_hk) * r_beta - # Update h and write back to sData - for i in range(vec_size): + # Step 4: Update h: h += k * v_new + for i in cutlass.range_constexpr(vec_size): r_h[i] += r_k[i] * v_new - sData[(row + row_offset, i * 32 + lane_id, stage)] = r_h[i] - # Store intermediate state - if cache_intermediate_states: + # Cache intermediate state if needed using 3D local_tile + autovec_copy + if cutlass.const_expr(cache_intermediate_states): flat_idx = i_n * T * HV + i_t * HV + i_hv - if v_idx < V: - for i in range(vec_size): - intermediate_states[ - (flat_idx, v_idx, i * 32 + lane_id) - ] = r_h[i] + inter_tile = cute.local_tile( + intermediate_states, + (1, 1, vec_size), + (flat_idx, v_idx, lane_in_group), + ) + cute.autovec_copy(r_h, inter_tile) - # Compute sum_hq = h @ q (overlaps with store above) + # Step 5: Compute output: sum_hq = h @ q (group reduction) sum_hq = 0.0 - for i in range(vec_size): + for i in cutlass.range_constexpr(vec_size): sum_hq += r_h[i] * r_q[i] + # Warp-level reduction for offset in [16, 8, 4, 2, 1]: sum_hq += cute.arch.shuffle_sync_bfly( sum_hq, offset=offset, mask=-1, mask_and_clamp=31 ) - o_idx = v_tiles * TILE_V_MTP + row + row_offset - if lane_id == 0 and o_idx < V: - sOutput[(i_t, o_idx)] = cutlass.Float32(sum_hq) - - # Write final h for this v_tile to h0_source - if not disable_state_update: - for row in range(0, TILE_V_MTP, 4): - row_offset = tidx // 32 - for i in range(vec_size): - gDst_h0[(0, row + row_offset, i * 32 + lane_id, v_tiles)] = ( - sData[(row + row_offset, i * 32 + lane_id, stage)] - ) - - # Final writeback - cute.arch.barrier() + # Write output (only lane 0 of each group) + if lane_in_group == 0: + o[(i_n, i_t, i_hv, v_idx)] = cutlass.BFloat16(sum_hq) - for i_t in range(T): - if tidx < V: - o[(i_n, i_t, i_hv, tidx)] = cutlass.BFloat16(sOutput[(i_t, tidx)]) + # Write final state back (if not disabled) using 3D local_tile + autovec_copy + if cutlass.const_expr(not disable_state_update): + h_tile_out = cute.local_tile( + h0_source, + (1, 1, vec_size), + (flat_state_idx, v_idx, lane_in_group), + ) + cute.autovec_copy(r_h, h_tile_out) @cute.jit @@ -2144,6 +2244,8 @@ def run_gdn_verify_kernel_mtp( H: cutlass.Constexpr[int], K: cutlass.Constexpr[int], V: cutlass.Constexpr[int], + tile_v: cutlass.Constexpr[int], # TILE_V - configurable for batch size + vec_size: cutlass.Constexpr[int], # 4 for full warp, 8 for half-warp use_initial_state: cutlass.Constexpr[bool], use_qk_l2norm: cutlass.Constexpr[bool], is_varlen: cutlass.Constexpr[bool], @@ -2157,47 +2259,26 @@ def run_gdn_verify_kernel_mtp( h0_source.layout.shape[2], ) - num_v_tiles = cute.ceil_div(v_dim, TILE_V_MTP) - batch_size = B * HV - vec_size = TILE_K_MTP // 32 # Each thread in a warp processes this many elements + num_v_tiles = cute.ceil_div(v_dim, tile_v) - # Composed Swizzle Layout to avoid shared memory bank conflicts - base_smem_layout = cute.make_layout( - (TILE_V_MTP, TILE_K_MTP, NUM_STAGES_MTP), - stride=(TILE_K_MTP, 1, TILE_V_MTP * TILE_K_MTP), - ) - swizzle = cute.make_swizzle(2, 3, 7) # Swizzle<2, 3, 7> - smem_layout_staged = cute.make_composed_layout(swizzle, 0, base_smem_layout) - - # Create tiled copy for G2S load - copy_atom = cute.make_copy_atom( - cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), - cutlass.Float32, - num_bits_per_copy=128, - ) - thread_layout = cute.make_layout((4, 32), stride=(32, 1)) - val_layout = cute.make_layout((1, 4)) - tiled_copy_load = cute.make_tiled_copy_tv(copy_atom, thread_layout, val_layout) + # Grid: (B * HV * num_v_tiles, 1, 1) - parallelize across V dimension + grid_size = B * HV * num_v_tiles - # Calculate shared memory size + # Shared memory for pre-computed q, k, g, beta smem_bytes = ( - 4 * TILE_V_MTP * TILE_K_MTP * NUM_STAGES_MTP - + 4 * T * v_dim - + 4 * T * (k_dim + 8) - + 4 * T * (k_dim + 8) - + 4 * T * (v_dim + 8) - + 4 * T - + 4 * T - + 128 + 4 * T * (k_dim + 8) # sQ + + 4 * T * (k_dim + 8) # sK + + 4 * T # sG + + 4 * T # sBeta + + 128 # alignment ) gdn_verify_kernel_mtp( - tiled_copy_load, h0_source, intermediate_states, - smem_layout_staged, vec_size, num_v_tiles, + tile_v, A_log, a, dt_bias, @@ -2223,7 +2304,7 @@ def run_gdn_verify_kernel_mtp( disable_state_update, cache_intermediate_states, ).launch( - grid=(batch_size, 1, 1), + grid=(grid_size, 1, 1), block=[NUM_THREADS_MTP, 1, 1], smem=smem_bytes, stream=stream, @@ -2244,6 +2325,8 @@ def _get_compiled_mtp_kernel( cache_intermediate_states: bool, scale: float, use_qk_l2norm: bool, + tile_v: int, # TILE_V - configurable for batch size + vec_size: int, # 4 for full warp, 8 for half-warp ): """Cache compiled MTP kernel for given configuration.""" return {} @@ -2320,6 +2403,10 @@ def gated_delta_rule_mtp( _, _, HV, V = v.shape pool_size = initial_state.shape[0] + # Dynamic TILE_V and vec_size selection based on batch size and sequence length + tile_v = get_tile_v_mtp(B, T) + vec_size = get_vec_size_mtp(B, T) + # Validate state shape assert initial_state.shape == (pool_size, HV, V, K), ( f"Expected initial_state shape [pool_size={pool_size}, HV={HV}, V={V}, K={K}], got {initial_state.shape}" @@ -2328,8 +2415,8 @@ def gated_delta_rule_mtp( # Validate K and V constraints assert K >= 128, f"K must be at least 128, got K={K}" assert V >= 128, f"V must be at least 128, got V={V}" - assert V % TILE_V_MTP == 0, ( - f"V must be divisible by {TILE_V_MTP} to prevent out-of-bounds access, got V={V}" + assert V % tile_v == 0, ( + f"V must be divisible by {tile_v} to prevent out-of-bounds access, got V={V}" ) # Validate dtypes @@ -2375,9 +2462,6 @@ def gated_delta_rule_mtp( cache_steps = T intermediate_states = torch.zeros(1, 1, 1, dtype=torch.float32, device=q.device) - # Create cu_seqlens - cu_seqlens = torch.zeros(B + 1, dtype=torch.int32, device=q.device) - # Compile kernel with TVM FFI (cached) cache_key = ( B, @@ -2392,9 +2476,16 @@ def gated_delta_rule_mtp( cache_intermediate_states, scale, use_qk_l2norm, + tile_v, + vec_size, ) cache = _get_compiled_mtp_kernel(*cache_key) + # Get or create cu_seqlens (cached per config) + if "cu_seqlens" not in cache or cache["cu_seqlens"].device != q.device: + cache["cu_seqlens"] = torch.zeros(B + 1, dtype=torch.int32, device=q.device) + cu_seqlens = cache["cu_seqlens"] + if "compiled" not in cache: stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -2436,6 +2527,8 @@ def gated_delta_rule_mtp( H=H, K=K, V=V, + tile_v=tile_v, + vec_size=vec_size, use_initial_state=True, use_qk_l2norm=use_qk_l2norm, is_varlen=False, @@ -2467,7 +2560,9 @@ def gated_delta_rule_mtp( ) # Copy state back if needed (no sync needed - PyTorch handles stream ordering) - if not disable_state_update: + # Only copy if state update is enabled AND initial_state was not contiguous + # (if contiguous, reshape returns a view and kernel updated state in-place) + if not disable_state_update and not initial_state.is_contiguous(): initial_state.copy_(h0_source.reshape(pool_size, HV, V, K)) # Convert output to target dtype if needed