207 changes: 178 additions & 29 deletions python/sglang/srt/layers/attention/fla/chunk_delta_h.py
@@ -13,7 +13,14 @@
prepare_chunk_offsets,
)
from sglang.srt.layers.attention.fla.op import exp, safe_exp
from sglang.srt.layers.attention.fla.utils import is_nvidia_hopper
from sglang.srt.layers.attention.fla.utils import IS_GLUON_SUPPORTED, is_nvidia_hopper

if IS_GLUON_SUPPORTED:
from sglang.srt.layers.attention.fla.gluon import TensorDescriptor, gl
from sglang.srt.layers.attention.fla.gluon.chunk_delta_h_gluon import (
chunk_gated_delta_rule_fwd_kernel_h_blockdim64_gluon,
)


NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
CHUNK_SIZE = 64
@@ -55,6 +62,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
INPLACE_UPDATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
TRANSPOSE_STATE: tl.constexpr,
):
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
@@ -101,23 +109,47 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(

# load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
if not TRANSPOSE_STATE:
p_h0_1 = tl.make_block_ptr(
h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
)
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
)
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
)
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
)
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
)
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
else:
# Column major: shape (K, V), stride (1, K)
p_h0_1 = tl.make_block_ptr(
h0, (K, V), (1, K), (0, i_v * BV), (64, BV), (0, 1)
)
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(
h0, (K, V), (1, K), (64, i_v * BV), (64, BV), (0, 1)
)
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(
h0, (K, V), (1, K), (128, i_v * BV), (64, BV), (0, 1)
)
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(
h0, (K, V), (1, K), (192, i_v * BV), (64, BV), (0, 1)
)
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

# main recurrence
for i_t in range(NT):
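
The TRANSPOSE_STATE branch above reads the same 64 x BV tiles of the initial state, only from storage whose last two dimensions are column-major: the block pointers swap strides (V, 1) for (1, K) and order (1, 0) for (0, 1). A minimal PyTorch sketch, not part of this PR and with invented sizes, of why both layouts address the same tiles:

```python
# Illustrative sketch only (not from this PR): how the row-major and
# "transposed" (column-major) state layouts relate. Sizes are made up.
import torch

K, V, BV = 128, 128, 64
i_v = 1  # which BV-wide column block a kernel program would handle

# Row-major state: shape (K, V), strides (V, 1) -> TRANSPOSE_STATE=False path
h0_row = torch.randn(K, V)

# Column-major storage: the same state kept physically as (V, K);
# transposing gives a (K, V) view with strides (1, K) -> TRANSPOSE_STATE=True path
h0_col_storage = h0_row.t().contiguous()   # physical [V, K]
h0_col_view = h0_col_storage.t()           # logical  [K, V], strides (1, K)
assert h0_col_view.stride() == (1, K)

# The 64 x BV tile addressed by make_block_ptr(..., (0, i_v * BV), (64, BV), ...)
tile_row = h0_row[0:64, i_v * BV:(i_v + 1) * BV]
tile_col = h0_col_view[0:64, i_v * BV:(i_v + 1) * BV]
assert torch.equal(tile_row, tile_col)  # same values, different memory walk
```

Presumably the point of transpose_state is that callers holding the recurrent state in [V, K] physical layout can pass it through unchanged, without materializing a transposed copy on the host.
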
@@ -252,23 +284,47 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(

# epilogue
if INPLACE_UPDATE:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
if not TRANSPOSE_STATE:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
else:
# Column major: shape (K, V), stride (1, K)
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
ht, (K, V), (1, K), (0, i_v * BV), (64, BV), (0, 1)
)
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(
ht, (K, V), (1, K), (64, i_v * BV), (64, BV), (0, 1)
)
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(
ht, (K, V), (1, K), (128, i_v * BV), (64, BV), (0, 1)
)
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(
ht, (K, V), (1, K), (192, i_v * BV), (64, BV), (0, 1)
)
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


def chunk_gated_delta_rule_fwd_h(
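
The epilogue mirrors the prologue: with TRANSPOSE_STATE set, the final state is written back through block pointers with strides (1, K) and order (0, 1), i.e. into a buffer whose physical layout is [V, K]. A self-contained Triton sketch, illustrative only (not from this PR) and requiring a CUDA device, of storing through such a column-major block pointer:

```python
# Sketch: store a (K, V) tile into a buffer that is physically [V, K],
# using the same strides=(1, K), order=(0, 1) pattern as the epilogue above.
import torch
import triton
import triton.language as tl


@triton.jit
def store_colmajor(dst, K: tl.constexpr, V: tl.constexpr):
    # dst physically holds a [V, K] tensor; view it as (K, V) with strides (1, K).
    p = tl.make_block_ptr(dst, (K, V), (1, K), (0, 0), (K, V), (0, 1))
    vals = tl.arange(0, K)[:, None] * V + tl.arange(0, V)[None, :]
    tl.store(p, vals.to(tl.float32), boundary_check=(0, 1))


K, V = 64, 64
buf = torch.empty(V, K, device="cuda", dtype=torch.float32)  # physical [V, K]
store_colmajor[(1,)](buf, K=K, V=V)
expected = (torch.arange(K)[:, None] * V + torch.arange(V)[None, :]).float().cuda()
assert torch.equal(buf.t(), expected)  # logical [K, V] view matches what was stored
```
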
@@ -281,6 +337,7 @@ def chunk_gated_delta_rule_fwd_h(
initial_state_indices: Optional[torch.Tensor] = None,
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
transpose_state: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, u.shape[-1]
H = u.shape[-2]
Expand All @@ -306,6 +363,97 @@ def chunk_gated_delta_rule_fwd_h(

v_new = torch.empty_like(u) if save_new_value else None

if IS_GLUON_SUPPORTED:
BK = K
BV = 64
IS_VARLEN = cu_seqlens is not None
num_warps = 4

gl_dtype = gl.bfloat16 if k.dtype == torch.bfloat16 else gl.float16
# k, w: [B, T, HK, K] / [B, T, H, K]
kw_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BK], gl_dtype)
# v, v_new: [B, T, H, V]
v_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BV], gl_dtype)
# h: [B, NT, H, K, V]
h_layout = gl.NVMMASharedLayout.get_default_for([1, 1, 1, BK, BV], gl_dtype)

k_desc = TensorDescriptor.from_tensor(k, [1, BT, 1, BK], kw_layout)
w_desc = TensorDescriptor.from_tensor(w, [1, BT, 1, BK], kw_layout)
v_desc = TensorDescriptor.from_tensor(u, [1, BT, 1, BV], v_layout)
h_desc = TensorDescriptor.from_tensor(h, [1, 1, 1, BK, BV], h_layout)

if initial_state is not None:
if transpose_state:
# transpose_state=True: state is stored in [V, K] physical layout.
# We need a view with shape [..., V, K] that is K-contiguous for TMA alignment.
if initial_state.stride(-1) == 1:
# Already K-contiguous [..., V, K], use directly
h0_view = initial_state
else:
# [..., K, V] with K-stride==1 -> transpose to [..., V, K]
h0_view = initial_state.transpose(-2, -1)
h0_layout = gl.NVMMASharedLayout.get_default_for(
[1, 1, BV, BK], gl.float32
)
h0_desc = TensorDescriptor.from_tensor(
h0_view, [1, 1, BV, BK], h0_layout
)
else:
h0_layout = gl.NVMMASharedLayout.get_default_for(
[1, 1, BK, BV], gl.float32
)
h0_desc = TensorDescriptor.from_tensor(
initial_state, [1, 1, BK, BV], h0_layout
)
else:
h0_desc = None

# For varlen, use scatter layout for v_new to handle boundary correctly
if save_new_value:
if IS_VARLEN:
v_new_scatter_layout = gl.NVMMASharedLayout.get_default_for(
[BT, BV], gl_dtype
)
v_new_desc = TensorDescriptor.from_tensor(
v_new.view(B * T, H * V), [1, BV], v_new_scatter_layout
)
else:
v_new_desc = TensorDescriptor.from_tensor(
v_new, [1, BT, 1, BV], v_layout
)
else:
v_new_desc = None

grid = (triton.cdiv(V, BV), N * H)
chunk_gated_delta_rule_fwd_kernel_h_blockdim64_gluon[grid](
k_desc=k_desc,
v_desc=v_desc,
w_desc=w_desc,
v_new_desc=v_new_desc,
g=g,
h_desc=h_desc,
h0_desc=h0_desc,
initial_state_indices=initial_state_indices,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
HK=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
USE_G=g is not None,
USE_INITIAL_STATE=initial_state is not None,
INPLACE_UPDATE=initial_state is not None,
SAVE_NEW_VALUE=v_new is not None,
IS_VARLEN=cu_seqlens is not None,
TRANSPOSE_STATE=transpose_state,
num_warps=num_warps,
)
return h, v_new
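
The transpose_state handling in the Gluon branch never copies the state; it only needs a view whose innermost K dimension is contiguous (the TMA alignment requirement mentioned in the comment). A small PyTorch sketch, with invented shapes and not part of this PR, of the same stride check and transpose:

```python
# Sketch: obtain a K-contiguous [..., V, K] view of the initial state without
# copying, mirroring the transpose_state handling above. Shapes are made up.
import torch

N, H, K, V = 2, 4, 128, 64


def k_contiguous_vk_view(initial_state: torch.Tensor) -> torch.Tensor:
    if initial_state.stride(-1) == 1:
        # Already a [..., V, K] tensor with contiguous K: use it directly.
        return initial_state
    # Otherwise it is a [..., K, V] view whose K stride is 1 (column-major in
    # the last two dims); transposing yields [..., V, K] with K contiguous.
    return initial_state.transpose(-2, -1)


# Case 1: state handed over as [N, H, V, K] with K already contiguous.
s_vk = torch.randn(N, H, V, K)
assert k_contiguous_vk_view(s_vk).stride(-1) == 1

# Case 2: state handed over as a [N, H, K, V] transposed view of the same storage.
s_kv = s_vk.transpose(-2, -1)            # K becomes the unit-stride dim
out = k_contiguous_vk_view(s_kv)
assert out.shape[-2:] == (V, K) and out.stride(-1) == 1
assert out.data_ptr() == s_vk.data_ptr()  # still no copy
```
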

def grid(meta):
return (triton.cdiv(V, meta["BV"]), N * H)

@@ -334,6 +482,7 @@ def grid(meta):
INPLACE_UPDATE=True,
SAVE_NEW_VALUE=v_new is not None,
IS_VARLEN=cu_seqlens is not None,
TRANSPOSE_STATE=transpose_state,
num_warps=4,
num_stages=2,
)
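
Both the Gluon launch and the Triton fallback iterate a two-dimensional grid over value blocks and (sequence, head) pairs; inside the kernel, program_id(1) is split back into a sequence index and a head index (i_n, i_h = i_nh // H, i_nh % H). A short sketch, with invented sizes and outside this PR, of how that grid enumerates the work:

```python
# Sketch: how the (cdiv(V, BV), N * H) launch grid maps to the (i_n, i_h, i_v)
# work items the kernel derives via program_id.
from math import ceil

N, H, V, BV = 2, 3, 256, 64  # invented sizes

grid = (ceil(V / BV), N * H)
work_items = []
for i_v in range(grid[0]):        # tl.program_id(0)
    for i_nh in range(grid[1]):   # tl.program_id(1)
        i_n, i_h = i_nh // H, i_nh % H
        work_items.append((i_n, i_h, i_v))

# Every (sequence, head) pair owns every BV-wide slice of V exactly once.
assert sorted(work_items) == [
    (n, h, v) for n in range(N) for h in range(H) for v in range(V // BV)
]
```
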
115 changes: 87 additions & 28 deletions python/sglang/srt/layers/attention/fla/chunk_o.py
@@ -10,7 +10,18 @@

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import exp, safe_exp
from sglang.srt.layers.attention.fla.utils import check_shared_mem, is_nvidia_hopper
from sglang.srt.layers.attention.fla.utils import (
IS_GLUON_SUPPORTED,
check_shared_mem,
is_nvidia_hopper,
)

if IS_GLUON_SUPPORTED:
from sglang.srt.layers.attention.fla.gluon import TensorDescriptor, gl
from sglang.srt.layers.attention.fla.gluon.chunk_o_gluon import (
chunk_fwd_kernel_o_gluon,
)


BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@@ -135,7 +146,7 @@ def chunk_fwd_o(
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
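
For context on the one-line BT change above: the old expression shrank the chunk size for short sequences (to a power of two, but at least 16), whereas the new code always uses the configured chunk_size. A tiny sketch, illustrative only, comparing the two:

```python
# Sketch (not from this PR): old vs. new chunk size BT in chunk_fwd_o.
def next_power_of_2(n: int) -> int:
    # Same result as triton.next_power_of_2 for positive n.
    return 1 << (n - 1).bit_length()


def bt_old(T: int, chunk_size: int = 64) -> int:
    return min(chunk_size, max(16, next_power_of_2(T)))


def bt_new(T: int, chunk_size: int = 64) -> int:
    return chunk_size


for T in (5, 20, 100):
    print(f"T={T}: old BT={bt_old(T)}, new BT={bt_new(T)}")
# T=5:   old BT=16, new BT=64
# T=20:  old BT=32, new BT=64
# T=100: old BT=64, new BT=64
```
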
@@ -145,30 +156,78 @@ def chunk_fwd_o(

o = torch.zeros_like(v)

def grid(meta):
return (triton.cdiv(V, meta["BV"]), NT, B * H)

chunk_fwd_kernel_o[grid](
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=128,
BV=64,
USE_G=g is not None,
IS_VARLEN=cu_seqlens is not None,
num_warps=4,
num_stages=2,
)
if IS_GLUON_SUPPORTED:
BK = 128 if K >= 128 else 64
BV = 128 if V >= 128 else 64
IS_VARLEN = cu_seqlens is not None
num_warps = 8 if BT >= 128 else 4

gl_dtype = gl.bfloat16 if q.dtype == torch.bfloat16 else gl.float16
qk_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BK], gl_dtype)
vo_layout = gl.NVMMASharedLayout.get_default_for([1, BT, 1, BV], gl_dtype)
h_layout = gl.NVMMASharedLayout.get_default_for([1, 1, 1, BK, BV], gl_dtype)
q_desc = TensorDescriptor.from_tensor(q, [1, BT, 1, BK], qk_layout)
k_desc = TensorDescriptor.from_tensor(k, [1, BT, 1, BK], qk_layout)
v_desc = TensorDescriptor.from_tensor(v, [1, BT, 1, BV], vo_layout)
h_desc = TensorDescriptor.from_tensor(h, [1, 1, 1, BK, BV], h_layout)
if IS_VARLEN:
o_layout = gl.NVMMASharedLayout.get_default_for([BT, BV], gl_dtype)
o_desc = TensorDescriptor.from_tensor(
o.view(B * T, H * V), [1, BV], o_layout
)
else:
o_desc = TensorDescriptor.from_tensor(o, [1, BT, 1, BV], vo_layout)

grid = (triton.cdiv(V, BV), NT, B * H)
chunk_fwd_kernel_o_gluon[grid](
q_desc=q_desc,
k_desc=k_desc,
v_desc=v_desc,
h_desc=h_desc,
o_desc=o_desc,
g=g,
g_gamma=None,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
HK=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
USE_G=g is not None,
IS_VARLEN=IS_VARLEN,
num_warps=num_warps,
)
else:

def grid(meta):
return (triton.cdiv(V, meta["BV"]), NT, B * H)

chunk_fwd_kernel_o[grid](
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=128,
BV=64,
USE_G=g is not None,
IS_VARLEN=cu_seqlens is not None,
num_warps=4,
num_stages=2,
)
return o
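
In the varlen branch, o is reinterpreted as a two-dimensional (B*T, H*V) tensor so the descriptor can scatter one [1, BV] row per token without crossing sequence boundaries; the same trick is used for v_new in chunk_delta_h.py. A PyTorch sketch, with invented shapes and not part of this PR, of why the flattened view addresses exactly the elements the 4-D layout would:

```python
# Sketch: the (B*T, H*V) view shares storage with the (B, T, H, V) tensor, and a
# [1, BV] row block in the view lands on the corresponding 4-D slice.
import torch

B, T, H, V, BV = 2, 8, 4, 128, 64
o = torch.zeros(B, T, H, V)
o2d = o.view(B * T, H * V)   # same storage, one row per (batch, token)

b, t, h, i_v = 1, 3, 2, 1
o2d[b * T + t, h * V + i_v * BV : h * V + (i_v + 1) * BV] = 1.0

# The write through the flattened view hits exactly the slice the non-varlen
# path would address with a [1, BT, 1, BV] block.
assert torch.equal(o[b, t, h, i_v * BV:(i_v + 1) * BV], torch.ones(BV))
```
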