diff --git a/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py b/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py index 8e7d845f238a..ea124c487bdd 100644 --- a/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py +++ b/benchmark/bench_linear_attention/bench_cutedsl_kda_decode.py @@ -51,8 +51,11 @@ def make_inputs( b = torch.randn(B, HV, device=device, dtype=dtype) # prefill params for chunk_kda must keep batch dim = 1 - prefill_g = torch.randn(1, B, HV, K, device=device, dtype=dtype) - prefill_beta = torch.sigmoid(torch.randn(1, B, HV, device=device, dtype=dtype)) + # chunk_kda requires g, beta, v to have the same head count as k (H), + # matching the real KimiLinear model where num_heads == num_kv_heads. + prefill_v = torch.randn(1, B, H, V, device=device, dtype=dtype) + prefill_g = torch.randn(1, B, H, K, device=device, dtype=dtype) + prefill_beta = torch.sigmoid(torch.randn(1, B, H, device=device, dtype=dtype)) cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32) @@ -66,8 +69,11 @@ def make_inputs( b = torch.randn(B, 1, HV, device=device, dtype=dtype) # prefill params for chunk_kda dense path - prefill_g = torch.randn(B, 1, HV, K, device=device, dtype=dtype) - prefill_beta = torch.sigmoid(torch.randn(B, 1, HV, device=device, dtype=dtype)) + # chunk_kda requires g, beta, v to have the same head count as k (H), + # matching the real KimiLinear model where num_heads == num_kv_heads. + prefill_v = torch.randn(B, 1, H, V, device=device, dtype=dtype) + prefill_g = torch.randn(B, 1, H, K, device=device, dtype=dtype) + prefill_beta = torch.sigmoid(torch.randn(B, 1, H, device=device, dtype=dtype)) cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.int32) else: @@ -94,6 +100,7 @@ def make_inputs( v=v, a=a, b=b, + prefill_v=prefill_v, prefill_g=prefill_g, prefill_beta=prefill_beta, A_log=A_log, @@ -147,12 +154,13 @@ def run_cutedsl(inp): def run_prefill_then_decode_baseline(inp): ssm_states = inp["ssm_states"].clone() + prefill_v_clone = inp["prefill_v"].clone() v_clone = inp["v"].clone() _ = chunk_kda( q=inp["q"], k=inp["k"], - v=v_clone, + v=prefill_v_clone, g=inp["prefill_g"], beta=inp["prefill_beta"], initial_state=ssm_states, @@ -182,12 +190,13 @@ def run_prefill_then_decode_baseline(inp): def run_prefill_then_decode_cutedsl(inp): ssm_states = inp["ssm_states"].clone() + prefill_v_clone = inp["prefill_v"].clone() v_clone = inp["v"].clone() _ = chunk_kda( q=inp["q"], k=inp["k"], - v=v_clone, + v=prefill_v_clone, g=inp["prefill_g"], beta=inp["prefill_beta"], initial_state=ssm_states, diff --git a/python/sglang/srt/layers/attention/fla/chunk_intra.py b/python/sglang/srt/layers/attention/fla/chunk_intra.py new file mode 100644 index 000000000000..344de6117ba4 --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/chunk_intra.py @@ -0,0 +1,661 @@ +# Adapted from flash-linear-attention project. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.fla.chunk_intra_token_parallel import ( + chunk_kda_fwd_intra_token_parallel, +) +from sglang.srt.layers.attention.fla.index import ( + prepare_chunk_indices, +) +from sglang.srt.layers.attention.fla.op import exp2, gather +from sglang.srt.layers.attention.fla.utils import ( + autotune_cache_kwargs, + is_gather_supported, + is_tf32_supported, +) + +if is_tf32_supported: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") +else: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee") + + +################################################################################ +# Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass +################################################################################ + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps) + for BK in [32, 64] + for num_warps in [1, 2, 4] + ], + key=["H", "K", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_fwd_kernel_inter_solve_fused( + q, + k, + g, + beta, + Aqk, + Akkd, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_SAFE_GATE: tl.constexpr, +): + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akkd. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akkd (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. Writes Akk_inv to Akk + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load( + chunk_indices + i_t * 2 + 1 + ).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + Aqk += (bos * H + i_h) * BT + Akk += (bos * H + i_h) * BT + Akkd += (bos * H + i_h) * BC + + o_i = tl.arange(0, BC) + m_tc1 = (i_tc1 + o_i) < T + m_tc2 = (i_tc2 + o_i) < T + m_tc3 = (i_tc3 + o_i) < T + + b_Aqk10 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk21 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk32 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) + + ################################################################################ + # off-diagonal blocks + ################################################################################ + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k0 = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0) + ) + p_g0 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0) + ) + b_k0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) + b_g0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) + + if i_tc1 < T: + p_q1 = tl.make_block_ptr( + q, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0) + ) + p_k1 = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0) + ) + p_g1 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0) + ) + # [BC, BK] + b_q1 = tl.load(p_q1, boundary_check=(0, 1)).to(tl.float32) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn1 = tl.load(g + i_tc1 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + b_gqn = tl.where(m_tc1[:, None], exp2(b_g1 - b_gn1[None, :]), 0) + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn1[None, :] - b_g0)) + # [BC, BC] + b_Aqk10 += tl.dot(b_q1 * b_gqn, b_kgt) + b_Akk10 += tl.dot(b_k1 * b_gqn, b_kgt) + + if i_tc2 < T: + p_q2 = tl.make_block_ptr( + q, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0) + ) + p_k2 = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0) + ) + p_g2 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0) + ) + # [BC, BK] + b_q2 = tl.load(p_q2, boundary_check=(0, 1)).to(tl.float32) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn2 = tl.load(g + i_tc2 * H * K + o_k, mask=m_k, other=0).to( + tl.float32 + ) + # [BC, BK] + b_gqn2 = tl.where(m_tc2[:, None], exp2(b_g2 - b_gn2[None, :]), 0) + b_qg2 = b_q2 * b_gqn2 + b_kg2 = b_k2 * b_gqn2 + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn2[None, :] - b_g0)) + b_Aqk20 += tl.dot(b_qg2, b_kgt) + b_Akk20 += tl.dot(b_kg2, b_kgt) + # [BC, BC] + b_kgt = tl.trans(b_k1 * exp2(b_gn2[None, :] - b_g1)) + # [BC, BC] + b_Aqk21 += tl.dot(b_qg2, b_kgt) + b_Akk21 += tl.dot(b_kg2, b_kgt) + + if i_tc3 < T: + p_q3 = tl.make_block_ptr( + q, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0) + ) + p_k3 = tl.make_block_ptr( + k, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0) + ) + p_g3 = tl.make_block_ptr( + g, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0) + ) + # [BC, BK] + b_q3 = tl.load(p_q3, boundary_check=(0, 1)).to(tl.float32) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn3 = tl.load(g + i_tc3 * H * K + o_k, mask=m_k, other=0).to( + tl.float32 + ) + # [BC, BK] + b_gqn3 = tl.where(m_tc3[:, None], exp2(b_g3 - b_gn3[None, :]), 0) + b_qg3 = b_q3 * b_gqn3 + b_kg3 = b_k3 * b_gqn3 + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn3[None, :] - b_g0)) + # [BC, BC] + b_Aqk30 += tl.dot(b_qg3, b_kgt) + b_Akk30 += tl.dot(b_kg3, b_kgt) + # [BK, BC] + b_kgt = tl.trans(b_k1 * exp2(b_gn3[None, :] - b_g1)) + # [BC, BC] + b_Aqk31 += tl.dot(b_qg3, b_kgt) + b_Akk31 += tl.dot(b_kg3, b_kgt) + # [BK, BC] + b_kgt = tl.trans(b_k2 * exp2(b_gn3[None, :] - b_g2)) + # [BC, BC] + b_Aqk32 += tl.dot(b_qg3, b_kgt) + b_Akk32 += tl.dot(b_kg3, b_kgt) + + ################################################################################ + # save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + if i_tc1 < T: + p_Aqk10 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0) + ) + tl.store( + p_Aqk10, (b_Aqk10 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + + p_b1 = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,) + ) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_Akk10 = b_Akk10 * b_b1[:, None] + if i_tc2 < T: + p_Aqk20 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0) + ) + p_Aqk21 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0) + ) + tl.store( + p_Aqk20, (b_Aqk20 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + tl.store( + p_Aqk21, (b_Aqk21 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + + p_b2 = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,) + ) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_Akk20 = b_Akk20 * b_b2[:, None] + b_Akk21 = b_Akk21 * b_b2[:, None] + if i_tc3 < T: + p_Aqk30 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0) + ) + p_Aqk31 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0) + ) + p_Aqk32 = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0) + ) + tl.store( + p_Aqk30, (b_Aqk30 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + tl.store( + p_Aqk31, (b_Aqk31 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + tl.store( + p_Aqk32, (b_Aqk32 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1) + ) + + p_b3 = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,) + ) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + b_Akk30 = b_Akk30 * b_b3[:, None] + b_Akk31 = b_Akk31 * b_b3[:, None] + b_Akk32 = b_Akk32 * b_b3[:, None] + + p_Akk00 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc0, 0), (BC, BC), (1, 0) + ) + p_Akk11 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc1, 0), (BC, BC), (1, 0) + ) + p_Akk22 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc2, 0), (BC, BC), (1, 0) + ) + p_Akk33 = tl.make_block_ptr( + Akkd, (T, BC), (H * BC, 1), (i_tc3, 0), (BC, BC), (1, 0) + ) + b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) + b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) + b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) + b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) + + ################################################################################ + # forward substitution on diagonals + ################################################################################ + + if not USE_SAFE_GATE: + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Ai00 = -tl.where(m_A, b_Ai00, 0) + b_Ai11 = -tl.where(m_A, b_Ai11, 0) + b_Ai22 = -tl.where(m_A, b_Ai22, 0) + b_Ai33 = -tl.where(m_A, b_Ai33, 0) + + for i in range(2, min(BC, T - i_tc0)): + b_a00 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a00 = tl.where(o_i < i, b_a00, 0.0) + b_a00 += tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(BC + 2, min(2 * BC, T - i_tc0)): + b_a11 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a11 = tl.where(o_i < i - BC, b_a11, 0.0) + b_a11 += tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i - BC)[:, None], b_a11, b_Ai11) + for i in range(2 * BC + 2, min(3 * BC, T - i_tc0)): + b_a22 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a22 = tl.where(o_i < i - 2 * BC, b_a22, 0.0) + b_a22 += tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i - 2 * BC)[:, None], b_a22, b_Ai22) + for i in range(3 * BC + 2, min(4 * BC, T - i_tc0)): + b_a33 = -tl.load(Akkd + (i_tc0 + i) * H * BC + o_i) + b_a33 = tl.where(o_i < i - 3 * BC, b_a33, 0.0) + b_a33 += tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i - 3 * BC)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + ################################################################################ + # compute merged inverse using off-diagonals + ################################################################################ + + # we used tf32 to maintain matrix inverse's precision whenever possible. + b_Ai10 = -tl.dot( + tl.dot(b_Ai11, b_Akk10, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai00, + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai21 = -tl.dot( + tl.dot(b_Ai22, b_Akk21, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai11, + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai32 = -tl.dot( + tl.dot(b_Ai33, b_Akk32, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai22, + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_Akk20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_Akk31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_Akk30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + ################################################################################ + # store full Akk_inv to Akk + ################################################################################ + + p_Akk00 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk10 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc1, BC), (BC, BC), (1, 0) + ) + p_Akk20 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk21 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0) + ) + p_Akk22 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc2, 2 * BC), (BC, BC), (1, 0) + ) + p_Akk30 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Akk31 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0) + ) + p_Akk32 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0) + ) + p_Akk33 = tl.make_block_ptr( + Akk, (T, BT), (H * BT, 1), (i_tc3, 3 * BC), (BC, BC), (1, 0) + ) + + tl.store(p_Akk00, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk10, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk11, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk20, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk21, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk22, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk30, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk31, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk32, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk33, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_fwd_kernel_intra_sub_chunk( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_GATHER: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load( + chunk_indices + i_t * 2 + 1 + ).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_c = i_ti + tl.arange(0, BC) + m_c = o_c < T + + q = q + (bos * H + i_h) * K + k = k + (bos * H + i_h) * K + g = g + (bos * H + i_h) * K + beta = beta + bos * H + i_h + Aqk = Aqk + (bos * H + i_h) * BT + Akk = Akk + (bos * H + i_h) * BC + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_ti, 0), (BC, BK), (1, 0)) + + p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + if USE_GATHER: + b_gn = gather( + b_g, tl.full([1, BK], min(BC // 2, T - i_ti - 1), dtype=tl.int16), axis=0 + ) + else: + # calculate offset + p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H * K + tl.arange(0, BK) + b_gn = tl.load(p_gn, mask=tl.arange(0, BK) < K, other=0.0) + b_gn = b_gn[None, :] + + # current block, keep numerical stability by subtracting the left boundary + # less than 85 to avoid overflow in exp2 + b_gm = (b_g - b_gn).to(tl.float32) + + b_gq = tl.where(m_c[:, None], exp2(b_gm), 0.0) + b_gk = tl.where(m_c[:, None], exp2(-b_gm), 0.0) + + b_kgt = tl.trans(b_k * b_gk) + + b_Aqk = tl.dot(b_q * b_gq, b_kgt) * scale + b_Akk = tl.dot(b_k * b_gq, b_kgt) * b_beta[:, None] + + o_i = tl.arange(0, BC) + m_Aqk = o_i[:, None] >= o_i[None, :] + m_Akk = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Aqk = tl.where(m_Aqk, b_Aqk, 0.0) + b_Akk = tl.where(m_Akk, b_Akk, 0.0) + + p_Aqk = tl.make_block_ptr( + Aqk, (T, BT), (H * BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0) + ) + p_Akk = tl.make_block_ptr(Akk, (T, BC), (H * BC, 1), (i_ti, 0), (BC, BC), (1, 0)) + tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk, b_Akk.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + + ################################################################################ + # forward substitution + ################################################################################ + + b_Ai = -b_Akk + for i in range(2, min(BC, T - i_ti)): + b_a = -tl.load(Akk + (i_ti + i) * H * BC + o_i) + b_a = tl.where(o_i < i, b_a, 0.0) + b_a += tl.sum(b_a[:, None] * b_Ai, 0) + b_Ai = tl.where((o_i == i)[:, None], b_a, b_Ai) + b_Ai += m_I + tl.store(p_Akk, b_Ai.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_fwd_intra( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor | None = None, + safe_gate: bool = False, + disable_recompute: bool = False, +): + B, T, H, K = k.shape + BT = chunk_size + BC = 16 + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + + Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype) + # Akk must be zero-initialized - kernel only writes lower triangular + Akk = torch.zeros(B, T, H, BT, device=k.device, dtype=k.dtype) + # Separate fp32 buffer for diagonal 16x16 blocks (for precision in solve_tril) + Akkd = torch.zeros(B, T, H, BC, device=k.device, dtype=torch.float32) + + # Step 1: Run token_parallel first to compute diagonal blocks into Akkd (fp32) + # Step 1: compute diagonal blocks into Akk_diag (fp32) + if safe_gate: + grid = (NT, NC, B * H) + BK = triton.next_power_of_2(K) + chunk_kda_fwd_kernel_intra_sub_chunk[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akkd, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + USE_GATHER=is_gather_supported, + ) + else: + Aqk, Akkd = chunk_kda_fwd_intra_token_parallel( + q=q, + k=k, + gk=gk, + beta=beta, + Aqk=Aqk, + Akk=Akkd, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT, + sub_chunk_size=BC, + ) + + # Step 2: Fused inter + solve_tril (works for both fixed-len and varlen) + grid = (NT, B * H) + chunk_kda_fwd_kernel_inter_solve_fused[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akkd=Akkd, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + USE_SAFE_GATE=safe_gate, + ) + from sglang.srt.layers.attention.fla.kda import ( + recompute_w_u_fwd as kda_recompute_w_u_fwd, + ) + + w, u, qg, kg = kda_recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=Akk, + q=q if disable_recompute else None, + gk=gk, + cu_seqlens=cu_seqlens, + ) + return w, u, qg, kg, Aqk, Akk diff --git a/python/sglang/srt/layers/attention/fla/chunk_intra_token_parallel.py b/python/sglang/srt/layers/attention/fla/chunk_intra_token_parallel.py new file mode 100644 index 000000000000..ec8bc848c839 --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/chunk_intra_token_parallel.py @@ -0,0 +1,197 @@ +# Adapted from flash-linear-attention project. +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# Token-parallel implementation of KDA intra chunk kernel + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.fla.op import exp2 +from sglang.srt.layers.attention.fla.utils import autotune_cache_kwargs + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BH": BH}, num_warps=num_warps) + for BH in [1, 2, 4, 8] + for num_warps in [1, 2, 4, 8] + ], + key=["K", "H"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T", "N"]) +def chunk_kda_fwd_kernel_intra_token_parallel( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + N, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BH: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_tg, i_hg = tl.program_id(0), tl.program_id(1) + + if IS_VARLEN: + i_n = 0 + left, right = 0, N + + # Unrolled binary search (max B=2^32) + # We can limit iterations based on expected max batch size if needed + # 20 iterations covers B=1M, usually enough + for _ in range(20): + if left < right: + mid = (left + right) // 2 + if i_tg < tl.load(cu_seqlens + mid + 1).to(tl.int32): + right = mid + else: + left = mid + 1 + i_n = left + + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) + T = eos - bos + i_t = i_tg - bos + else: + bos = (i_tg // T) * T + i_t = i_tg % T + + if i_t >= T: + return + + i_c = i_t // BT + i_s = (i_t % BT) // BC + i_tc = i_c * BT + i_ts = i_tc + i_s * BC + + q += bos * H * K + k += bos * H * K + g += bos * H * K + Aqk += bos * H * BT + Akk += bos * H * BC + beta += bos * H + + o_h = tl.arange(0, BH) + o_k = tl.arange(0, BK) + m_h = (i_hg * BH + o_h) < H + m_k = o_k < K + + p_q = tl.make_block_ptr( + q + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_k = tl.make_block_ptr( + k + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_g = tl.make_block_ptr( + g + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_beta = tl.make_block_ptr(beta + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,)) + # [BH, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_k = b_k * tl.load(p_beta, boundary_check=(0,)).to(tl.float32)[:, None] + + for j in range(i_ts, min(i_t + 1, min(T, i_ts + BC))): + p_kj = tl.make_block_ptr( + k + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + p_gj = tl.make_block_ptr( + g + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0) + ) + # [BH, BK] + b_kj = tl.load(p_kj, boundary_check=(0, 1)).to(tl.float32) + b_gj = tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) + + b_kgj = b_kj * exp2(b_g - b_gj) + + b_kgj = tl.where(m_k[None, :], b_kgj, 0.0) + # [BH] + b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale + b_Akk = tl.sum(b_k * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) + + tl.store( + Aqk + i_t * H * BT + (i_hg * BH + o_h) * BT + j % BT, + b_Aqk.to(Aqk.dtype.element_ty), + mask=m_h, + ) + tl.store( + Akk + i_t * H * BC + (i_hg * BH + o_h) * BC + j - i_ts, + b_Akk.to(Akk.dtype.element_ty), + mask=m_h, + ) + + +def chunk_kda_fwd_intra_token_parallel( + q: torch.Tensor, + k: torch.Tensor, + gk: torch.Tensor, + beta: torch.Tensor, + Aqk: torch.Tensor, + Akk: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + sub_chunk_size: int = 16, +) -> None: + """ + Token-parallel implementation: each token gets its own thread block. + Supports both fixed-length and variable-length sequences. + Reduces wasted computation on padding. + + Writes directly to Aqk and Akk tensors (in-place). + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + gk: [B, T, H, K] cumsum of gates + beta: [B, T, H] + Aqk: [B, T, H, BT] output tensor to write to + Akk: [B, T, H, BC] output tensor for diagonal blocks (fp32) + scale: attention scale + chunk_size: BT (default 64) + sub_chunk_size: BC (default 16) + """ + B, T, H, K = q.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BT = chunk_size + BC = sub_chunk_size + + def grid(meta): + return (B * T, triton.cdiv(H, meta["BH"])) + + BK = triton.next_power_of_2(K) + + chunk_kda_fwd_kernel_intra_token_parallel[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + N=N, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + ) + return Aqk, Akk diff --git a/python/sglang/srt/layers/attention/fla/kda.py b/python/sglang/srt/layers/attention/fla/kda.py index 3f17b21cce54..a8d5cb405ea9 100644 --- a/python/sglang/srt/layers/attention/fla/kda.py +++ b/python/sglang/srt/layers/attention/fla/kda.py @@ -9,6 +9,7 @@ import triton.language as tl from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h +from sglang.srt.layers.attention.fla.chunk_intra import chunk_kda_fwd_intra from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum from sglang.srt.layers.attention.fla.fused_norm_gate import layer_norm_gated_fwd from sglang.srt.layers.attention.fla.fused_recurrent import ( @@ -17,7 +18,6 @@ from sglang.srt.layers.attention.fla.index import prepare_chunk_indices from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd from sglang.srt.layers.attention.fla.op import exp, log -from sglang.srt.layers.attention.fla.solve_tril import solve_tril from sglang.srt.layers.attention.fla.utils import is_amd BT_LIST_AUTOTUNE = [32, 64, 128] @@ -863,27 +863,19 @@ def chunk_kda_fwd( ): chunk_size = 64 g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) - # the intra Aqk is kept in fp32 - # the computation has very marginal effect on the entire throughput - A, Aqk = chunk_kda_scaled_dot_kkt_fwd( + + # Fused: scaled_dot_kkt + solve_tril + recompute_w_u + w, u, _, kg, Aqk, _ = chunk_kda_fwd_intra( q=q, k=k, + v=v, gk=g, beta=beta, scale=scale, cu_seqlens=cu_seqlens, - output_dtype=torch.float32, - ) - A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) - w, u, _, kg = recompute_w_u_fwd( - k=k, - v=v, - beta=beta, - A=A, - gk=g, - cu_seqlens=cu_seqlens, + chunk_size=chunk_size, ) - del A + h, v_new = chunk_gated_delta_rule_fwd_h( k=kg, w=w,