diff --git a/python/sglang/srt/layers/attention/fla/chunk_delta_h.py b/python/sglang/srt/layers/attention/fla/chunk_delta_h.py index 38a7c8f297e3..82e490426b47 100644 --- a/python/sglang/srt/layers/attention/fla/chunk_delta_h.py +++ b/python/sglang/srt/layers/attention/fla/chunk_delta_h.py @@ -13,7 +13,14 @@ prepare_chunk_offsets, ) from sglang.srt.layers.attention.fla.op import exp, safe_exp -from sglang.srt.layers.attention.fla.utils import is_nvidia_hopper +from sglang.srt.layers.attention.fla.utils import IS_GLUON_SUPPORTED, is_nvidia_hopper + +if IS_GLUON_SUPPORTED: + from sglang.srt.layers.attention.fla.gluon import TensorDescriptor, gl + from sglang.srt.layers.attention.fla.gluon.chunk_delta_h_gluon import ( + chunk_gated_delta_rule_fwd_kernel_h_blockdim64_gluon, + ) + NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] CHUNK_SIZE = 64 @@ -55,6 +62,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( INPLACE_UPDATE: tl.constexpr, SAVE_NEW_VALUE: tl.constexpr, IS_VARLEN: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, ): i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H @@ -101,23 +109,47 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( # load initial state if USE_INITIAL_STATE: - p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) - b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) - if K > 64: - p_h0_2 = tl.make_block_ptr( - h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + if not TRANSPOSE_STATE: + p_h0_1 = tl.make_block_ptr( + h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0) ) - b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) - if K > 128: - p_h0_3 = tl.make_block_ptr( - h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) - ) - b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) - if K > 192: - p_h0_4 = tl.make_block_ptr( - h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) 
+ if K > 64: + p_h0_2 = tl.make_block_ptr( + h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr( + h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr( + h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + else: + # Column major: shape (K, V), stride (1, K) + p_h0_1 = tl.make_block_ptr( + h0, (K, V), (1, K), (0, i_v * BV), (64, BV), (0, 1) ) - b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr( + h0, (K, V), (1, K), (64, i_v * BV), (64, BV), (0, 1) + ) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr( + h0, (K, V), (1, K), (128, i_v * BV), (64, BV), (0, 1) + ) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr( + h0, (K, V), (1, K), (192, i_v * BV), (64, BV), (0, 1) + ) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) # main recurrence for i_t in range(NT): @@ -252,23 +284,47 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( # epilogue if INPLACE_UPDATE: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) - tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - if K > 64: - p_ht = tl.make_block_ptr( - ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) - ) - tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - if K > 128: + if not TRANSPOSE_STATE: p_ht = tl.make_block_ptr( - ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0) ) - tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - if K > 
192: + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + else: + # Column major: shape (K, V), stride (1, K) p_ht = tl.make_block_ptr( - ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ht, (K, V), (1, K), (0, i_v * BV), (64, BV), (0, 1) ) - tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr( + ht, (K, V), (1, K), (64, i_v * BV), (64, BV), (0, 1) + ) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr( + ht, (K, V), (1, K), (128, i_v * BV), (64, BV), (0, 1) + ) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr( + ht, (K, V), (1, K), (192, i_v * BV), (64, BV), (0, 1) + ) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) def chunk_gated_delta_rule_fwd_h( @@ -281,6 +337,7 @@ def chunk_gated_delta_rule_fwd_h( initial_state_indices: Optional[torch.Tensor] = None, save_new_value: bool = True, cu_seqlens: Optional[torch.LongTensor] = None, + transpose_state: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, u.shape[-1] H = u.shape[-2] @@ -306,6 +363,97 @@ def chunk_gated_delta_rule_fwd_h( v_new = torch.empty_like(u) if save_new_value else None + if IS_GLUON_SUPPORTED: + BK = K + BV = 64 + IS_VARLEN = cu_seqlens is not 
None + num_warps = 4 + + gl_dtype = gl.bfloat16 if k.dtype == torch.bfloat16 else gl.float16 + # k, w: [B, T, HK, K] / [B, T, H, K] + kw_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BK], gl_dtype) + # v, v_new: [B, T, H, V] + v_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BV], gl_dtype) + # h: [B, NT, H, K, V] + h_layout = gl.NVMMASharedLayout.get_default_for([1, 1, 1, BK, BV], gl_dtype) + + k_desc = TensorDescriptor.from_tensor(k, [1, BT, 1, BK], kw_layout) + w_desc = TensorDescriptor.from_tensor(w, [1, BT, 1, BK], kw_layout) + v_desc = TensorDescriptor.from_tensor(u, [1, BT, 1, BV], v_layout) + h_desc = TensorDescriptor.from_tensor(h, [1, 1, 1, BK, BV], h_layout) + + if initial_state is not None: + if transpose_state: + # transpose_state=True: state is stored in [V, K] physical layout. + # We need a view with shape [..., V, K] that is K-contiguous for TMA alignment. + if initial_state.stride(-1) == 1: + # Already K-contiguous [..., V, K], use directly + h0_view = initial_state + else: + # [..., K, V] with K-stride==1 -> transpose to [..., V, K] + h0_view = initial_state.transpose(-2, -1) + h0_layout = gl.NVMMASharedLayout.get_default_for( + [1, 1, BV, BK], gl.float32 + ) + h0_desc = TensorDescriptor.from_tensor( + h0_view, [1, 1, BV, BK], h0_layout + ) + else: + h0_layout = gl.NVMMASharedLayout.get_default_for( + [1, 1, BK, BV], gl.float32 + ) + h0_desc = TensorDescriptor.from_tensor( + initial_state, [1, 1, BK, BV], h0_layout + ) + else: + h0_desc = None + + # For varlen, use scatter layout for v_new to handle boundary correctly + if save_new_value: + if IS_VARLEN: + v_new_scatter_layout = gl.NVMMASharedLayout.get_default_for( + [BT, BV], gl_dtype + ) + v_new_desc = TensorDescriptor.from_tensor( + v_new.view(B * T, H * V), [1, BV], v_new_scatter_layout + ) + else: + v_new_desc = TensorDescriptor.from_tensor( + v_new, [1, BT, 1, BV], v_layout + ) + else: + v_new_desc = None + + grid = (triton.cdiv(V, BV), N * H) + 
chunk_gated_delta_rule_fwd_kernel_h_blockdim64_gluon[grid]( + k_desc=k_desc, + v_desc=v_desc, + w_desc=w_desc, + v_new_desc=v_new_desc, + g=g, + h_desc=h_desc, + h0_desc=h0_desc, + initial_state_indices=initial_state_indices, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HK=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_G=g is not None, + USE_INITIAL_STATE=initial_state is not None, + INPLACE_UPDATE=initial_state is not None, + SAVE_NEW_VALUE=v_new is not None, + IS_VARLEN=cu_seqlens is not None, + TRANSPOSE_STATE=transpose_state, + num_warps=num_warps, + ) + return h, v_new + def grid(meta): return (triton.cdiv(V, meta["BV"]), N * H) @@ -334,6 +482,7 @@ def grid(meta): INPLACE_UPDATE=True, SAVE_NEW_VALUE=v_new is not None, IS_VARLEN=cu_seqlens is not None, + TRANSPOSE_STATE=transpose_state, num_warps=4, num_stages=2, ) diff --git a/python/sglang/srt/layers/attention/fla/chunk_o.py b/python/sglang/srt/layers/attention/fla/chunk_o.py index bb89421eb872..4c29b887053e 100644 --- a/python/sglang/srt/layers/attention/fla/chunk_o.py +++ b/python/sglang/srt/layers/attention/fla/chunk_o.py @@ -10,7 +10,18 @@ from sglang.srt.layers.attention.fla.index import prepare_chunk_indices from sglang.srt.layers.attention.fla.op import exp, safe_exp -from sglang.srt.layers.attention.fla.utils import check_shared_mem, is_nvidia_hopper +from sglang.srt.layers.attention.fla.utils import ( + IS_GLUON_SUPPORTED, + check_shared_mem, + is_nvidia_hopper, +) + +if IS_GLUON_SUPPORTED: + from sglang.srt.layers.attention.fla.gluon import TensorDescriptor, gl + from sglang.srt.layers.attention.fla.gluon.chunk_o_gluon import ( + chunk_fwd_kernel_o_gluon, + ) + BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] @@ -135,7 +146,7 @@ def chunk_fwd_o( ) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] H = v.shape[-2] - BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BT = chunk_size chunk_indices 
= ( prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None ) @@ -145,30 +156,78 @@ def chunk_fwd_o( o = torch.zeros_like(v) - def grid(meta): - return (triton.cdiv(V, meta["BV"]), NT, B * H) - - chunk_fwd_kernel_o[grid]( - q, - k, - v, - h, - g, - o, - cu_seqlens, - chunk_indices, - scale, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - BK=128, - BV=64, - USE_G=g is not None, - IS_VARLEN=cu_seqlens is not None, - num_warps=4, - num_stages=2, - ) + if IS_GLUON_SUPPORTED: + BK = 128 if K >= 128 else 64 + BV = 128 if V >= 128 else 64 + IS_VARLEN = cu_seqlens is not None + num_warps = 8 if BT >= 128 else 4 + + gl_dtype = gl.bfloat16 if q.dtype == torch.bfloat16 else gl.float16 + qk_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BK], gl_dtype) + vo_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BV], gl_dtype) + h_layout = gl.NVMMASharedLayout.get_default_for([1, 1, 1, BK, BV], gl_dtype) + q_desc = TensorDescriptor.from_tensor(q, [1, BT, 1, BK], qk_layout) + k_desc = TensorDescriptor.from_tensor(k, [1, BT, 1, BK], qk_layout) + v_desc = TensorDescriptor.from_tensor(v, [1, BT, 1, BV], vo_layout) + h_desc = TensorDescriptor.from_tensor(h, [1, 1, 1, BK, BV], h_layout) + if IS_VARLEN: + o_layout = gl.NVMMASharedLayout.get_default_for([BT, BV], gl_dtype) + o_desc = TensorDescriptor.from_tensor( + o.view(B * T, H * V), [1, BV], o_layout + ) + else: + o_desc = TensorDescriptor.from_tensor(o, [1, BT, 1, BV], vo_layout) + + grid = (triton.cdiv(V, BV), NT, B * H) + chunk_fwd_kernel_o_gluon[grid]( + q_desc=q_desc, + k_desc=k_desc, + v_desc=v_desc, + h_desc=h_desc, + o_desc=o_desc, + g=g, + g_gamma=None, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + HK=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_G=g is not None, + IS_VARLEN=IS_VARLEN, + num_warps=num_warps, + ) + else: + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + 
cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=128, + BV=64, + USE_G=g is not None, + IS_VARLEN=cu_seqlens is not None, + num_warps=4, + num_stages=2, + ) return o diff --git a/python/sglang/srt/layers/attention/fla/cumsum.py b/python/sglang/srt/layers/attention/fla/cumsum.py index 39d2f4722778..0e237cd6d233 100644 --- a/python/sglang/srt/layers/attention/fla/cumsum.py +++ b/python/sglang/srt/layers/attention/fla/cumsum.py @@ -9,11 +9,77 @@ import triton.language as tl from sglang.srt.layers.attention.fla.index import prepare_chunk_indices -from sglang.srt.layers.attention.fla.utils import check_shared_mem, input_guard +from sglang.srt.layers.attention.fla.utils import ( + FLA_CUMSUM_SCALAR_VECTORIZATION, + check_shared_mem, + input_guard, +) BS_LIST = [32, 64] if check_shared_mem() else [16, 32] +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_vectorization_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + BH: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + n_groups = tl.cdiv(H, BH) + i_b, i_hg = i_bh // n_groups, i_bh % n_groups + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load( + chunk_indices + i_t * 2 + 1 + ).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load( + cu_seqlens + i_n + 1 + ).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + bos * H, (H, T), (T, 1), (i_hg * BH, i_t * BT), (BH, BT), (1, 0) + ) + p_o = tl.make_block_ptr( + o + bos * H, (H, T), (T, 1), (i_hg * BH, i_t * BT), (BH, BT), (1, 0) + ) + else: + p_s = tl.make_block_ptr( + s + bos * H, (T, H), (H, 1), (i_t * BT, i_hg * BH), (BT, BH), (1, 0) + ) + p_o = tl.make_block_ptr( + o + bos * H, (T, H), (H, 1), (i_t * 
BT, i_hg * BH), (BT, BH), (1, 0) + ) + # [BT, BH] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if HEAD_FIRST: + b_o = tl.cumsum(b_s, axis=1) + if REVERSE: + b_z = tl.sum(b_s, axis=1) + b_o = -b_o + b_z[:, None] + b_s + else: + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None, :] + b_s + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + # @triton.autotune( # configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], # key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], @@ -178,23 +244,45 @@ def chunk_local_cumsum_scalar( NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) grid = (NT, B * H) - chunk_local_cumsum_scalar_kernel[grid]( - s=g_org, - o=g, - scale=scale, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - B=B, - H=H, - BT=BT, - HEAD_FIRST=head_first, - REVERSE=reverse, - HAS_SCALE=scale is not None, - IS_VARLEN=cu_seqlens is not None, - num_warps=8, - num_stages=3, - ) + if not FLA_CUMSUM_SCALAR_VECTORIZATION: + chunk_local_cumsum_scalar_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + HAS_SCALE=scale is not None, + IS_VARLEN=cu_seqlens is not None, + num_warps=8, + num_stages=3, + ) + else: + BH = min(8, triton.next_power_of_2(H)) + grid = (NT, B * triton.cdiv(H, BH)) + chunk_local_cumsum_scalar_vectorization_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + BH=BH, + HEAD_FIRST=head_first, + REVERSE=reverse, + HAS_SCALE=scale is not None, + IS_VARLEN=cu_seqlens is not None, + num_warps=4, + num_stages=3, + ) return g diff --git a/python/sglang/srt/layers/attention/fla/gluon/__init__.py 
b/python/sglang/srt/layers/attention/fla/gluon/__init__.py new file mode 100644 index 000000000000..84d5bea0419a --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/gluon/__init__.py @@ -0,0 +1,42 @@ +import torch +import triton + +try: + from triton.experimental import gluon + from triton.experimental.gluon import language as gl + from triton.experimental.gluon.language.nvidia.blackwell import ( + TensorMemoryLayout, + allocate_tensor_memory, + fence_async_shared, + get_tmem_reg_layout, + mbarrier, + tcgen05_commit, + tcgen05_mma, + tma, + ) + from triton.experimental.gluon.nvidia.hopper import TensorDescriptor +except ImportError as e: + raise ImportError( + f">>> Failed to import Gluon in current triton version {triton.__version__} and " + f">>> Platform {torch.cuda.get_device_capability()}.\n" + f">>> Gluon/Blackwell features require: \n" + f">>> 1. Triton >= 3.6.0 \n" + f">>> 2. NVIDIA GPU (compute capability == 10.0)\n" + f">>> 3. Pytorch >= 2.9.0 \n" + f">>> Error: {e}\n" + f">>> Set FLA_USE_GLUON=0 to disable and continue." 
+ ) from e + +__all__ = [ + "gluon", + "gl", + "TensorMemoryLayout", + "allocate_tensor_memory", + "fence_async_shared", + "get_tmem_reg_layout", + "mbarrier", + "tcgen05_commit", + "tcgen05_mma", + "tma", + "TensorDescriptor", +] diff --git a/python/sglang/srt/layers/attention/fla/gluon/chunk_delta_h_gluon.py b/python/sglang/srt/layers/attention/fla/gluon/chunk_delta_h_gluon.py new file mode 100644 index 000000000000..16ccb478c84d --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/gluon/chunk_delta_h_gluon.py @@ -0,0 +1,293 @@ +from sglang.srt.layers.attention.fla.gluon import ( + TensorMemoryLayout, + allocate_tensor_memory, + fence_async_shared, + get_tmem_reg_layout, + gl, + gluon, + mbarrier, + tcgen05_commit, + tcgen05_mma, + tma, +) + + +@gluon.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64_gluon( + k_desc, + v_desc, + w_desc, + v_new_desc, + g, + h_desc, + h0_desc, + initial_state_indices, + cu_seqlens, + chunk_offsets, + T, + H: gl.constexpr, + HK: gl.constexpr, + K: gl.constexpr, + V: gl.constexpr, + BT: gl.constexpr, + BK: gl.constexpr, + BV: gl.constexpr, + USE_G: gl.constexpr, + USE_INITIAL_STATE: gl.constexpr, + INPLACE_UPDATE: gl.constexpr, + SAVE_NEW_VALUE: gl.constexpr, + IS_VARLEN: gl.constexpr, + TRANSPOSE_STATE: gl.constexpr, +): + i_v, i_nh = gl.program_id(0), gl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + i_hk = i_h // (H // HK) + + if IS_VARLEN: + bos, eos = gl.load(cu_seqlens + i_n).to(gl.int32), gl.load( + cu_seqlens + i_n + 1 + ).to(gl.int32) + T = eos - bos + NT = gl.cdiv(T, BT) + boh = gl.load(chunk_offsets + i_n).to(gl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = gl.cdiv(T, BT) + boh = i_n * NT + + index = gl.load(initial_state_indices + i_n).to(gl.int32) + NUM_WARPS: gl.constexpr = gl.num_warps() + + # Allocate shared memory for TMA loads + dtype: gl.constexpr = k_desc.dtype + # k, w: [1, BT, 1, BK] + k_smem = gl.allocate_shared_memory(dtype, k_desc.block_type.shape, 
k_desc.layout) + w_smem = gl.allocate_shared_memory(dtype, w_desc.block_type.shape, w_desc.layout) + # v: [1, BT, 1, BV] + v_smem = gl.allocate_shared_memory(dtype, v_desc.block_type.shape, v_desc.layout) + # h0/ht: [1, 1, BK, BV] or [1, 1, BV, BK] if TRANSPOSE_STATE, dtype=fp32 (4D TMA) + h0_smem = ( + gl.allocate_shared_memory(gl.float32, h0_desc.block_type.shape, h0_desc.layout) + if USE_INITIAL_STATE + else None + ) + # h: [1, 1, 1, BK, BV], dtype=bf16/fp16 (5D TMA) + h_smem = gl.allocate_shared_memory(dtype, h_desc.block_type.shape, h_desc.layout) + + # For varlen: use scatter layout [BT, BV]; for non-varlen: use [1, BT, 1, BV] + if SAVE_NEW_VALUE: + if IS_VARLEN: + offsets_layout: gl.constexpr = gl.SliceLayout( + 0, gl.BlockedLayout([1, 4], [32, 1], [1, NUM_WARPS], [1, 0]) + ) + v_new_scatter_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for( + [BT, BV], dtype + ) + v_new_smem = gl.allocate_shared_memory( + dtype, [BT, BV], v_new_scatter_layout + ) + else: + v_new_smem = gl.allocate_shared_memory( + dtype, v_new_desc.block_type.shape, v_new_desc.layout + ) + else: + v_new_smem = None + + # Allocate mbarriers for TMA synchronization + tma_bar_k = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(tma_bar_k, count=1) + tma_bar_w = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(tma_bar_w, count=1) + tma_bar_v = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(tma_bar_v, count=1) + mma_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(mma_bar, count=1) + tma_phase_k = 0 + tma_phase_w = 0 + tma_phase_v = 0 + mma_phase = 0 + + # Tensor memory layout for accumulation + v_tmem_layout: gl.constexpr = TensorMemoryLayout([BT, BV], col_stride=1) + h_tmem_layout: gl.constexpr = TensorMemoryLayout([BK, BV], col_stride=1) + v_reg_layout: gl.constexpr = get_tmem_reg_layout( + gl.float32, [BT, BV], v_tmem_layout, NUM_WARPS 
+ ) + v_reg_layout_16: gl.constexpr = get_tmem_reg_layout( + dtype, [BT, BV], v_tmem_layout, NUM_WARPS + ) + h_reg_layout: gl.constexpr = get_tmem_reg_layout( + gl.float32, [BK, BV], h_tmem_layout, NUM_WARPS + ) + g_layout_bt: gl.constexpr = gl.SliceLayout(dim=1, parent=v_reg_layout) + + # Allocate tensor memory for MMA operations + v_tmem = allocate_tensor_memory(gl.float32, [BT, BV], v_tmem_layout) + kv_tmem = allocate_tensor_memory(gl.float32, [BK, BV], h_tmem_layout) + + # Initialize h accumulators + b_h = gl.zeros([BK, BV], dtype=gl.float32, layout=h_reg_layout) + + # Prologue: prefetch w[0] early (overlap with h0 load + transpose) + mbarrier.expect(tma_bar_w, w_desc.block_type.nbytes) + if IS_VARLEN: + tma.async_copy_global_to_shared(w_desc, [0, bos, i_h, 0], tma_bar_w, w_smem) + else: + tma.async_copy_global_to_shared(w_desc, [i_n, 0, i_h, 0], tma_bar_w, w_smem) + + # Load initial state + if USE_INITIAL_STATE: + tma_bar_h0 = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(tma_bar_h0, count=1) + tma_phase_h0 = 0 + mbarrier.expect(tma_bar_h0, h0_desc.block_type.nbytes) + if TRANSPOSE_STATE: + tma.async_copy_global_to_shared( + h0_desc, [index, i_h, i_v * BV, 0], tma_bar_h0, h0_smem + ) + else: + tma.async_copy_global_to_shared( + h0_desc, [index, i_h, 0, i_v * BV], tma_bar_h0, h0_smem + ) + mbarrier.wait(tma_bar_h0, phase=tma_phase_h0) + tma_phase_h0 ^= 1 + if TRANSPOSE_STATE: + # smem permute: load [BV,BK] smem as [BK,BV] via permuted view + h0_smem_2d = h0_smem.reshape([BV, BK]) + h0_smem_t = h0_smem_2d.permute((1, 0)) # [BK, BV] view + b_h0 = h0_smem_t.load(h_reg_layout) + else: + h0_smem_2d = h0_smem.reshape([BK, BV]) + b_h0 = h0_smem_2d.load(h_reg_layout) + b_h = b_h + b_h0 + mbarrier.invalidate(tma_bar_h0) + + # Main Loop + for i_t in range(NT): + if IS_VARLEN: + i_b, i_t_h, i_t_kvw = 0, boh + i_t, bos + i_t * BT + else: + i_b, i_t_h, i_t_kvw = i_n, i_t, i_t * BT + + # Prefetch v and k early (max overlap with gate + 
h_store + w_wait + MMA1) + mbarrier.expect(tma_bar_v, v_desc.block_type.nbytes) + tma.async_copy_global_to_shared( + v_desc, [i_b, i_t_kvw, i_h, i_v * BV], tma_bar_v, v_smem + ) + mbarrier.expect(tma_bar_k, k_desc.block_type.nbytes) + tma.async_copy_global_to_shared( + k_desc, [i_b, i_t_kvw, i_hk, 0], tma_bar_k, k_smem + ) + + # Compute gating values (scalar ops, overlap with TMA in-flight) + if USE_G: + last_idx = T - 1 if i_t == NT - 1 else (i_t + 1) * BT - 1 + bg_last = gl.load(g + (bos + last_idx) * H + i_h) + g_offset = i_t * BT + gl.arange(0, BT, layout=g_layout_bt) + g_mask = g_offset < T + b_g = gl.load(g + (bos + g_offset) * H + i_h, mask=g_mask, other=0) + bg_last_exp = gl.exp(bg_last) + + if SAVE_NEW_VALUE and IS_VARLEN: + t_limit_right = gl.minimum(T - i_t * BT, BT) + t_offsets = gl.arange(0, BT, layout=offsets_layout) + row_valid = t_offsets < t_limit_right + x_offsets = gl.where(row_valid, bos + i_t * BT + t_offsets, 0x7FFFFFFF) + + # Store h_i to smem + h_smem_2d = h_smem.reshape([BK, BV]) + h_smem_2d.store(b_h.to(dtype)) + + # Wait for w (prefetched in prologue or previous iteration) + mbarrier.wait(tma_bar_w, phase=tma_phase_w) + tma_phase_w ^= 1 + w_smem_2d = w_smem.reshape([BT, BK]) + + # TMA store h to global + fence_async_shared() + tma.async_copy_shared_to_global(h_desc, [i_b, i_t_h, i_h, 0, i_v * BV], h_smem) + # w @ h: [BT, BK] @ [BK, BV] -> [BT, BV] + tcgen05_mma(w_smem_2d, h_smem_2d, v_tmem, use_acc=False) + tcgen05_commit(mma_bar) + mbarrier.wait(mma_bar, phase=mma_phase) + mma_phase ^= 1 + + # Prefetch w for next iteration (w_smem is free after MMA1 completes) + if i_t < NT - 1: + mbarrier.expect(tma_bar_w, w_desc.block_type.nbytes) + tma.async_copy_global_to_shared( + w_desc, [i_b, i_t_kvw + BT, i_h, 0], tma_bar_w, w_smem + ) + + v_acc_reg = v_tmem.load(v_reg_layout) + mbarrier.wait(tma_bar_v, phase=tma_phase_v) + tma_phase_v ^= 1 + v_smem_2d = v_smem.reshape([BT, BV]) + v_reg = v_smem_2d.load(v_reg_layout_16) + v_new_reg = v_reg - 
v_acc_reg + + # store v_new to global + if SAVE_NEW_VALUE: + if IS_VARLEN: + v_new_smem.store(v_new_reg.to(dtype)) + fence_async_shared() + tma.async_scatter(v_new_desc, x_offsets, i_h * V + i_v * BV, v_new_smem) + else: + v_new_smem_2d = v_new_smem.reshape([BT, BV]) + v_new_smem_2d.store(v_new_reg.to(dtype)) + fence_async_shared() + tma.async_copy_shared_to_global( + v_new_desc, [i_b, i_t_kvw, i_h, i_v * BV], v_new_smem + ) + + # Apply gating + if USE_G: + if i_t == NT - 1: + v_new_reg = ( + v_new_reg * gl.where(g_mask, gl.exp(bg_last - b_g), 0)[:, None] + ) + else: + v_new_reg = v_new_reg * gl.exp(bg_last - b_g)[:, None] + b_h *= bg_last_exp + + # Store gated v_new back to v_smem + v_new_reg = v_new_reg.to(dtype) + v_smem_2d.store(v_new_reg) + + # Wait for k + mbarrier.wait(tma_bar_k, phase=tma_phase_k) + tma_phase_k ^= 1 + k_smem_2d = k_smem.reshape([BT, BK]) + k_t = k_smem_2d.permute((1, 0)) + + # fence v + fence_async_shared() + # k.T @ v_new -> kv_tmem: [BK, BT] @ [BT, BV] -> [BK, BV] + tcgen05_mma(k_t, v_smem_2d, kv_tmem, use_acc=False) + tcgen05_commit(mma_bar) + mbarrier.wait(mma_bar, phase=mma_phase) + mma_phase ^= 1 + + # h_i += k_i.T @ v_new + b_kv = kv_tmem.load(h_reg_layout) + b_h = b_h + b_kv + + if INPLACE_UPDATE: + if TRANSPOSE_STATE: + # smem permute: store [BK,BV] reg to [BV,BK] smem via permuted view + h0_smem_2d = h0_smem.reshape([BV, BK]) + h0_smem_t = h0_smem_2d.permute((1, 0)) # [BK, BV] view + h0_smem_t.store(b_h) + fence_async_shared() + tma.async_copy_shared_to_global(h0_desc, [index, i_h, i_v * BV, 0], h0_smem) + else: + h0_smem_2d = h0_smem.reshape([BK, BV]) + h0_smem_2d.store(b_h) + fence_async_shared() + tma.async_copy_shared_to_global(h0_desc, [index, i_h, 0, i_v * BV], h0_smem) + + mbarrier.invalidate(tma_bar_k) + mbarrier.invalidate(tma_bar_w) + mbarrier.invalidate(tma_bar_v) + mbarrier.invalidate(mma_bar) + tma.store_wait(pendings=0) diff --git a/python/sglang/srt/layers/attention/fla/gluon/chunk_o_gluon.py 
b/python/sglang/srt/layers/attention/fla/gluon/chunk_o_gluon.py new file mode 100644 index 000000000000..d0bff1b7535a --- /dev/null +++ b/python/sglang/srt/layers/attention/fla/gluon/chunk_o_gluon.py @@ -0,0 +1,210 @@ +from sglang.srt.layers.attention.fla.gluon import ( + TensorMemoryLayout, + allocate_tensor_memory, + fence_async_shared, + get_tmem_reg_layout, + gl, + gluon, + mbarrier, + tcgen05_commit, + tcgen05_mma, + tma, +) + + +@gluon.jit +def _mask_scalar(A, col_limit_right, s, i): + col_lim_right_s = col_limit_right - s + col_lim_right_cur = max(col_lim_right_s, 0) + mask = -1 << col_lim_right_cur + mask_i_bit = (mask & (1 << i)) == 0 + return gl.where(mask_i_bit, A, 0.0) + + +@gluon.jit +def _apply_causal_mask(A, col_limit_right): + # Apply causal mask via a bitmask calculated for each block of 16 elements. + # This allows the efficient R2P (register to predicate) instruction to be used at the SASS level. + # ref https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78 + offs_n = gl.arange(0, A.shape[1])[None, :] + s = offs_n & ~0xF + i = offs_n & 0xF + return gl.map_elementwise(_mask_scalar, A, col_limit_right, s, i) + + +@gluon.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o_gluon( + q_desc, + k_desc, + v_desc, + h_desc, + o_desc, + g, + g_gamma, + cu_seqlens, + chunk_indices, + scale, + T, + H: gl.constexpr, + HK: gl.constexpr, + K: gl.constexpr, + V: gl.constexpr, + BT: gl.constexpr, + BK: gl.constexpr, + BV: gl.constexpr, + USE_G: gl.constexpr, + IS_VARLEN: gl.constexpr, + num_warps: gl.constexpr, +): + i_v, i_t, i_bh = gl.program_id(0), gl.program_id(1), gl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_hk = i_h // (H // HK) + i_tg = i_t + if IS_VARLEN: + # global chunk id + # sequence id, chunk id of the current sequence + i_n, i_t = gl.load(chunk_indices + i_t * 2).to(gl.int32), gl.load( + chunk_indices + i_t * 2 + 1 + ).to(gl.int32) + bos, eos = gl.load(cu_seqlens + i_n).to(gl.int32), gl.load( + 
cu_seqlens + i_n + 1 + ).to(gl.int32) + T = eos - bos + # TMA coordinate: qkvo=[0, bos+i_t*BT, ...], h=[0, i_tg, ...] + i_b, i_t_start = 0, bos + i_t * BT + else: + NT = gl.cdiv(T, BT) + bos, eos = i_b * T, i_b * T + T + # TMA coordinate: qkvo=[i_b, i_t*BT, ...], h=[i_b, i_t, ...] + i_b, i_t_start = i_b, i_t * BT + dtype: gl.constexpr = q_desc.dtype + q_smem = gl.allocate_shared_memory(dtype, q_desc.block_type.shape, q_desc.layout) + k_smem = gl.allocate_shared_memory(dtype, k_desc.block_type.shape, k_desc.layout) + h_smem = gl.allocate_shared_memory(dtype, h_desc.block_type.shape, h_desc.layout) + tma_bar_qh = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(tma_bar_qh, count=1) + tma_bar_kv = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(tma_bar_kv, count=1) + mma_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(mma_bar, count=1) + tma_phase = 0 + mma_phase = 0 + o_tmem_layout: gl.constexpr = TensorMemoryLayout([BT, BV], col_stride=1) + o_tmem = allocate_tensor_memory(gl.float32, [BT, BV], o_tmem_layout) + A_tmem_layout: gl.constexpr = TensorMemoryLayout([BT, BT], col_stride=1) + A_tmem = allocate_tensor_memory(gl.float32, [BT, BT], A_tmem_layout) + use_acc = False + for i_k in range(gl.cdiv(K, BK)): + # Load q and h for o computation + mbarrier.expect(tma_bar_qh, q_desc.block_type.nbytes + h_desc.block_type.nbytes) + tma.async_copy_global_to_shared( + q_desc, [i_b, i_t_start, i_hk, i_k * BK], tma_bar_qh, q_smem + ) + tma.async_copy_global_to_shared( + h_desc, [i_b, i_tg, i_h, i_k * BK, i_v * BV], tma_bar_qh, h_smem + ) + # Load k for A computation + mbarrier.expect(tma_bar_kv, k_desc.block_type.nbytes) + tma.async_copy_global_to_shared( + k_desc, [i_b, i_t_start, i_hk, i_k * BK], tma_bar_kv, k_smem + ) + # wait qh, compute o = q @ h + mbarrier.wait(tma_bar_qh, phase=tma_phase) + q_smem_2d = q_smem.reshape([BT, BK]) # [1, BT, 1, BK] -> [BT, BK] + 
h_smem_2d = h_smem.reshape([BK, BV]) # [1, 1, 1, BK, BV] -> [BK, BV] + # [BT, BK] @ [BK, BV] -> [BT, BV] + tcgen05_mma(q_smem_2d, h_smem_2d, o_tmem, use_acc=use_acc) + # wait k (overlaps with o MMA), compute A = q @ k_t + mbarrier.wait(tma_bar_kv, phase=tma_phase) + tma_phase ^= 1 + k_t = k_smem.reshape([BT, BK]).permute( + (1, 0) + ) # [1, BT, 1, BK] -> [BT, BK] -> [BK, BT] + # [BT, BK] @ [BK, BT] -> [BT, BT] + tcgen05_mma(q_smem_2d, k_t, A_tmem, use_acc=use_acc) + # single commit mma + tcgen05_commit(mma_bar) + mbarrier.wait(mma_bar, phase=mma_phase) + mma_phase ^= 1 + use_acc = True + # async load v + v_smem = gl.allocate_shared_memory(dtype, v_desc.block_type.shape, v_desc.layout) + mbarrier.expect(tma_bar_kv, v_desc.block_type.nbytes) + tma.async_copy_global_to_shared( + v_desc, [i_b, i_t_start, i_h, i_v * BV], tma_bar_kv, v_smem + ) + + o_reg_layout: gl.constexpr = get_tmem_reg_layout( + gl.float32, + [BT, BV], + o_tmem_layout, + num_warps, + ) + A_reg_layout: gl.constexpr = get_tmem_reg_layout( + gl.float32, + [BT, BT], + A_tmem_layout, + num_warps, + ) + o_reg = o_tmem.load(o_reg_layout) + A_reg = A_tmem.load(A_reg_layout) + if USE_G: + g_layout_o: gl.constexpr = gl.SliceLayout(dim=1, parent=o_reg_layout) + # Use the chunk ID of the current sequence. 
+ g_idx = i_t * BT + gl.arange(0, BT, layout=g_layout_o) + g_offs = (bos + g_idx) * H + i_h + b_g = gl.load(g + g_offs, mask=g_idx < T, other=0.0) + o_reg = o_reg * gl.exp(b_g)[:, None] + g_layout_A_row: gl.constexpr = gl.SliceLayout(dim=1, parent=A_reg_layout) + g_layout_A_col: gl.constexpr = gl.SliceLayout(dim=0, parent=A_reg_layout) + b_g_row = gl.convert_layout(b_g, g_layout_A_row) + b_g_col = gl.convert_layout(b_g, g_layout_A_col) + A_reg = A_reg * gl.exp(b_g_row[:, None] - b_g_col[None, :]) + # causal mask + # col_limit_right[row_idx] indicates the number of visible columns in each row + # for example: BT=64, col_limit_right = [1, 2, 3, ..., 63, 64] + # col_limit_right = gl.minimum(gl.arange(0, BT)[:, None] + 1, T - i_t * BT) + col_limit_right = gl.arange(0, BT)[:, None] + 1 + A_reg = _apply_causal_mask(A_reg, col_limit_right) + + A_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BT, BT], dtype) + A_smem = gl.allocate_shared_memory(dtype, [BT, BT], A_smem_layout) + A_smem.store(A_reg.to(dtype)) + # fence A_smem + fence_async_shared() + acc_tmem = allocate_tensor_memory(gl.float32, [BT, BV], o_tmem_layout) + # wait v_smem + mbarrier.wait(tma_bar_kv, phase=tma_phase) + mbarrier.invalidate(tma_bar_kv) + # intra chunk A @ v + v_smem_2d = v_smem.reshape([BT, BV]) # [1, BT, 1, BV] -> [BT, BV] + tcgen05_mma(A_smem, v_smem_2d, acc_tmem, use_acc=False) + tcgen05_commit(mma_bar) + mbarrier.wait(mma_bar, phase=mma_phase) + mbarrier.invalidate(mma_bar) + acc = acc_tmem.load(o_reg_layout) + o_reg = o_reg * scale + acc * scale + # store o to global memory + if IS_VARLEN: + # for example: T=126, BT=64, i_t=1 → t_limit_right = min(126-64, 64) = min(62, 64) = 62 + t_limit_right = gl.minimum(T - i_t * BT, BT) + offsets_layout: gl.constexpr = gl.SliceLayout( + 0, gl.BlockedLayout([1, 4], [32, 1], [1, gl.num_warps()], [1, 0]) + ) + t_offsets = gl.arange(0, BT, layout=offsets_layout) # [0, 1, 2, ..., 63] + mask_o = t_offsets < t_limit_right # [T, T, ..., T, F, 
from sglang.srt.layers.attention.fla.gluon import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    fence_async_shared,
    get_tmem_reg_layout,
    gl,
    gluon,
    mbarrier,
    tcgen05_commit,
    tcgen05_mma,
    tma,
)


@gluon.jit
def recompute_w_u_fwd_kernel_gluon(
    k_desc,
    v_desc,
    w_desc,
    u_desc,
    A_desc,
    beta,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: gl.constexpr,
    HK: gl.constexpr,
    K: gl.constexpr,
    V: gl.constexpr,
    BT: gl.constexpr,
    BK: gl.constexpr,
    BV: gl.constexpr,
    USE_G: gl.constexpr,
    IS_VARLEN: gl.constexpr,
):
    """Compute ``u = A @ (v * beta)`` and ``w = A @ (k * beta [* exp(g)])`` for one chunk.

    One program handles one (chunk, batch*head) pair: ``program_id(0)`` is the
    chunk slot and ``program_id(1)`` is ``batch * H + head``.  V is walked in
    ``BV``-wide tiles and K in ``BK``-wide tiles; the two walks are independent,
    so ``cdiv(K, BK)`` and ``cdiv(V, BV)`` do not have to match.

    Args:
        k_desc, v_desc, A_desc: TMA descriptors used for the tile loads, indexed
            as ``[batch, time, head, feature]``.
        w_desc, u_desc: TMA descriptors used for the stores.  With
            ``IS_VARLEN`` they are addressed through ``tma.async_scatter`` with
            per-row offsets and a ``head * dim + tile`` column offset;
            otherwise with the same 4-D coordinates as the loads.
        beta: pointer read at ``(bos + t) * H + head``.
        g: pointer read with the same indexing as ``beta``; only dereferenced
            when ``USE_G`` is set.
        cu_seqlens, chunk_indices: only read when ``IS_VARLEN`` is set.
            ``chunk_indices`` holds ``(sequence, chunk-in-sequence)`` pairs and
            ``cu_seqlens`` the begin/end offset of every sequence.
        T: sequence length; replaced by ``eos - bos`` in the varlen case.
        H, HK: head counts; the k head for head ``i_h`` is ``i_h // (H // HK)``.
        K, V, BT, BK, BV: feature sizes and tile sizes.
    """
    i_t, i_bh = gl.program_id(0), gl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    # several heads share one k head
    i_hk = i_h // (H // HK)

    NUM_WARPS: gl.constexpr = gl.num_warps()
    if IS_VARLEN:
        # Both loads use the *launch* chunk slot, so i_n must be read before
        # i_t is overwritten with the chunk index inside the sequence.
        i_n = gl.load(chunk_indices + i_t * 2).to(gl.int32)
        i_t = gl.load(chunk_indices + i_t * 2 + 1).to(gl.int32)
        bos = gl.load(cu_seqlens + i_n).to(gl.int32)
        eos = gl.load(cu_seqlens + i_n + 1).to(gl.int32)
        T = eos - bos
        # TMA coordinate: [0, bos + i_t * BT, ...]
        i_b, i_t_start = 0, bos + i_t * BT
    else:
        bos = i_b * T
        # TMA coordinate: [i_b, i_t * BT, ...]
        i_t_start = i_t * BT

    dtype: gl.constexpr = k_desc.dtype

    # ---- shared-memory staging buffers ----
    k_smem = gl.allocate_shared_memory(dtype, k_desc.block_type.shape, k_desc.layout)
    v_smem = gl.allocate_shared_memory(dtype, v_desc.block_type.shape, v_desc.layout)
    A_smem = gl.allocate_shared_memory(dtype, A_desc.block_type.shape, A_desc.layout)
    k_smem_2d = k_smem.reshape([BT, BK])
    v_smem_2d = v_smem.reshape([BT, BV])
    A_smem_2d = A_smem.reshape([BT, BT])

    # ---- fp32 accumulators in tensor memory ----
    k_tmem_layout: gl.constexpr = TensorMemoryLayout([BT, BK], col_stride=1)
    v_tmem_layout: gl.constexpr = TensorMemoryLayout([BT, BV], col_stride=1)
    w_tmem = allocate_tensor_memory(gl.float32, [BT, BK], k_tmem_layout)
    u_tmem = allocate_tensor_memory(gl.float32, [BT, BV], v_tmem_layout)

    # ---- barriers: one TMA + one MMA barrier per stream (v/u and k/w), plus A ----
    tma_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(tma_bar, count=1)
    mma_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(mma_bar, count=1)
    tma_k_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(tma_k_bar, count=1)
    mma_k_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(mma_k_bar, count=1)
    A_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(A_bar, count=1)

    tma_phase = 0
    mma_phase = 0
    tma_phase_k = 0
    mma_phase_k = 0

    # ---- kick off A and the first v / k tiles (global -> shared) ----
    mbarrier.expect(A_bar, A_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(A_desc, [i_b, i_t_start, i_h, 0], A_bar, A_smem)
    mbarrier.expect(tma_bar, v_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(
        v_desc, [i_b, i_t_start, i_h, 0 * BV], tma_bar, v_smem
    )
    mbarrier.expect(tma_k_bar, k_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(
        k_desc, [i_b, i_t_start, i_hk, 0 * BK], tma_k_bar, k_smem
    )

    # ---- per-row scalars: beta (and exp(g)); rows at or past T read as 0 ----
    beta_layout: gl.constexpr = gl.BlockedLayout(
        [1, 1], [32, 1], [1, NUM_WARPS], [1, 0]
    )
    layout_1d_beta_x: gl.constexpr = gl.SliceLayout(dim=1, parent=beta_layout)
    indices_t = i_t * BT + gl.arange(0, BT, layout=layout_1d_beta_x)
    mask = indices_t < T
    beta += bos * H + i_h
    beta_reg = gl.load(beta + indices_t * H, mask=mask, other=0.0)[:, None]
    if USE_G:
        g += bos * H + i_h
        g_reg = gl.load(g + indices_t * H, mask=mask, other=0.0)
        g_reg = gl.exp(g_reg)[:, None]

    v_reg_layout: gl.constexpr = get_tmem_reg_layout(
        gl.float32, (BT, BV), v_tmem_layout, NUM_WARPS
    )
    k_reg_layout: gl.constexpr = get_tmem_reg_layout(
        gl.float32, (BT, BK), k_tmem_layout, NUM_WARPS
    )

    if IS_VARLEN:
        # Row offsets for the scatter stores; identical for every u and w tile,
        # so they are built once.  Rows at or past the end of the sequence get
        # the sentinel 0x7FFFFFFF so that the scatter does not write them.
        # NOTE(review): relies on the TMA scatter dropping out-of-range rows.
        t_limit_right = gl.minimum(T - i_t * BT, BT)
        offsets_layout: gl.constexpr = gl.SliceLayout(
            0, gl.BlockedLayout([1, 4], [32, 1], [1, gl.num_warps()], [1, 0])
        )
        t_offsets = gl.arange(0, BT, layout=offsets_layout)
        mask_o = t_offsets < t_limit_right
        x_offsets = gl.where(mask_o, i_t_start + t_offsets, 0x7FFFFFFF)

    # A is the left operand of every MMA below.
    mbarrier.wait(A_bar, phase=0)

    # ================= u = A @ (v * beta), one BV tile at a time =================
    NV = gl.cdiv(V, BV)
    for i_v in range(NV):
        mbarrier.wait(tma_bar, phase=tma_phase)
        # v * beta, written back in place so it can feed the MMA
        v_reg = v_smem_2d.load(v_reg_layout)
        beta_reg_convert_v = gl.convert_layout(beta_reg, layout=v_reg_layout)
        vb_reg = v_reg * beta_reg_convert_v
        vb_reg = vb_reg.to(v_desc.dtype)
        v_smem_2d.store(vb_reg)
        fence_async_shared()

        # [BT, BT] @ [BT, BV] -> [BT, BV]
        tcgen05_mma(A_smem_2d, v_smem_2d, u_tmem, use_acc=False)
        tcgen05_commit(mma_bar)
        mbarrier.wait(mma_bar, phase=mma_phase)

        # Prefetch the next v tile only after the MMA barrier has completed:
        # v_smem is the MMA's B operand and must not be overwritten while the
        # MMA may still be reading it.  The load still overlaps the epilogue.
        if i_v < NV - 1:
            mbarrier.expect(tma_bar, v_desc.block_type.nbytes)
            tma.async_copy_global_to_shared(
                v_desc, [i_b, i_t_start, i_h, (i_v + 1) * BV], tma_bar, v_smem
            )

        u_reg = u_tmem.load(v_reg_layout)
        if IS_VARLEN:
            u_smem_2d = gl.allocate_shared_memory(dtype, [BT, BV], u_desc.layout)
            u_smem_2d.store(u_reg.to(dtype))
            fence_async_shared()
            tma.async_scatter(u_desc, x_offsets, i_h * V + i_v * BV, u_smem_2d)
        else:
            u_smem = gl.allocate_shared_memory(
                dtype, u_desc.block_type.shape, u_desc.layout
            )
            u_smem_2d = u_smem.reshape([BT, BV])
            u_smem_2d.store(u_reg.to(dtype))
            fence_async_shared()
            tma.async_copy_shared_to_global(
                u_desc, [i_b, i_t_start, i_h, i_v * BV], u_smem
            )
        # the staging buffer is reused by the next iteration
        tma.store_wait(pendings=0)

        tma_phase ^= 1
        mma_phase ^= 1

    # ============ w = A @ (k * beta [* exp(g)]), one BK tile at a time ============
    # Separate loop with its own trip count: K and V need not have the same
    # number of tiles.
    NK = gl.cdiv(K, BK)
    for i_k in range(NK):
        mbarrier.wait(tma_k_bar, phase=tma_phase_k)
        # k * beta (and * exp(g)), written back in place
        k_reg = k_smem_2d.load(k_reg_layout)
        beta_reg_convert_k = gl.convert_layout(beta_reg, layout=k_reg_layout)
        kb_reg = k_reg * beta_reg_convert_k
        if USE_G:
            g_reg_convert_k = gl.convert_layout(g_reg, layout=k_reg_layout)
            kb_reg *= g_reg_convert_k
        kb_reg = kb_reg.to(k_desc.dtype)
        k_smem_2d.store(kb_reg)
        fence_async_shared()

        # [BT, BT] @ [BT, BK] -> [BT, BK]
        tcgen05_mma(A_smem_2d, k_smem_2d, w_tmem, use_acc=False)
        tcgen05_commit(mma_k_bar)
        mbarrier.wait(mma_k_bar, phase=mma_phase_k)

        # Same ordering rule as for v: refill k_smem only once the MMA is done.
        if i_k < NK - 1:
            mbarrier.expect(tma_k_bar, k_desc.block_type.nbytes)
            tma.async_copy_global_to_shared(
                k_desc, [i_b, i_t_start, i_hk, (i_k + 1) * BK], tma_k_bar, k_smem
            )

        w_reg = w_tmem.load(k_reg_layout)
        if IS_VARLEN:
            w_smem_2d = gl.allocate_shared_memory(dtype, [BT, BK], w_desc.layout)
            w_smem_2d.store(w_reg.to(dtype))
            fence_async_shared()
            tma.async_scatter(w_desc, x_offsets, i_h * K + i_k * BK, w_smem_2d)
        else:
            w_smem = gl.allocate_shared_memory(
                dtype, w_desc.block_type.shape, w_desc.layout
            )
            w_smem_2d = w_smem.reshape([BT, BK])
            w_smem_2d.store(w_reg.to(dtype))
            fence_async_shared()
            tma.async_copy_shared_to_global(
                w_desc, [i_b, i_t_start, i_h, i_k * BK], w_smem
            )
        # the staging buffer is reused by the next iteration
        tma.store_wait(pendings=0)

        tma_phase_k ^= 1
        mma_phase_k ^= 1

    mbarrier.invalidate(tma_bar)
    mbarrier.invalidate(mma_bar)
    mbarrier.invalidate(A_bar)
    mbarrier.invalidate(tma_k_bar)
    mbarrier.invalidate(mma_k_bar)
+ ) def get_all_max_shared_mem(): diff --git a/python/sglang/srt/layers/attention/fla/wy_fast.py b/python/sglang/srt/layers/attention/fla/wy_fast.py index 757e5621087b..4c1214817e8a 100644 --- a/python/sglang/srt/layers/attention/fla/wy_fast.py +++ b/python/sglang/srt/layers/attention/fla/wy_fast.py @@ -9,6 +9,13 @@ import triton.language as tl from sglang.srt.layers.attention.fla.index import prepare_chunk_indices +from sglang.srt.layers.attention.fla.utils import IS_GLUON_SUPPORTED + +if IS_GLUON_SUPPORTED: + from sglang.srt.layers.attention.fla.gluon import TensorDescriptor, gl + from sglang.srt.layers.attention.fla.gluon.wy_fast_gluon import ( + recompute_w_u_fwd_kernel_gluon, + ) # @triton.autotune( @@ -128,28 +135,73 @@ def recompute_w_u_fwd( BV = 64 u = torch.empty_like(v) w = k.new_empty(B, T, H, K) - recompute_w_u_fwd_kernel[(NT, B * H)]( - k=k, - v=v, - beta=beta, - w=w, - u=u, - A=A, - g=g_cumsum, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - BK=BK, - BV=BV, - IS_VARLEN=cu_seqlens is not None, - num_warps=4, - num_stages=3, - ) + + if IS_GLUON_SUPPORTED: + # tma desc init + kw_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BK], gl.float16) + vu_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BV], gl.float16) + A_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BT], gl.float16) + k_desc = TensorDescriptor.from_tensor(k, [1, BT, 1, BK], kw_layout) + v_desc = TensorDescriptor.from_tensor(v, [1, BT, 1, BV], vu_layout) + A_desc = TensorDescriptor.from_tensor(A, [1, BT, 1, BT], A_layout) + if cu_seqlens is not None: + w_layout = gl.NVMMASharedLayout.get_default_for([BT, BK], gl.float16) + w_desc = TensorDescriptor.from_tensor( + w.view(B * T, H * K), [1, BK], w_layout + ) + u_layout = gl.NVMMASharedLayout.get_default_for([BT, BV], gl.float16) + u_desc = TensorDescriptor.from_tensor( + u.view(B * T, H * V), [1, BV], u_layout + ) + else: + w_desc = 
TensorDescriptor.from_tensor(w, [1, BT, 1, BK], kw_layout) + u_desc = TensorDescriptor.from_tensor(u, [1, BT, 1, BV], vu_layout) + + recompute_w_u_fwd_kernel_gluon[(NT, B * H)]( + k_desc=k_desc, + v_desc=v_desc, + w_desc=w_desc, + u_desc=u_desc, + A_desc=A_desc, + beta=beta, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + HK=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_G=g_cumsum is not None, + IS_VARLEN=cu_seqlens is not None, + num_warps=4, + ) + else: + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + IS_VARLEN=cu_seqlens is not None, + num_warps=4, + num_stages=3, + ) return w, u