Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions python/sglang/srt/layers/attention/fla/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@
from einops import rearrange

from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from sglang.srt.layers.attention.fla.chunk_fwd import chunk_gated_delta_rule_fwd_intra
from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
chunk_scaled_dot_kkt_fwd,
)
Comment thread
yuan-luo marked this conversation as resolved.
from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
from sglang.srt.layers.attention.fla.index import (
prepare_chunk_indices,
)
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
from sglang.srt.layers.attention.fla.solve_tril import solve_tril
from sglang.srt.layers.attention.fla.utils import (
SUPPRESS_LEVEL,
autocast_custom_fwd,
input_guard,
)
from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd

CHUNK_SIZE = 64
Comment thread
yuan-luo marked this conversation as resolved.


def chunk_gated_delta_rule_fwd(
Expand All @@ -33,21 +34,20 @@ def chunk_gated_delta_rule_fwd(
initial_state: torch.Tensor,
initial_state_indices: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_indices: torch.LongTensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
g = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE, cu_seqlens=cu_seqlens)

Comment thread
yuan-luo marked this conversation as resolved.
# fused kkt + solve_tril + recompute_w_u
w, u, A = chunk_gated_delta_rule_fwd_intra(
k=k,
v=v,
g=g,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)

h, v_new = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
Expand Down Expand Up @@ -97,6 +97,11 @@ def forward(
q = l2norm_fwd(q)
k = l2norm_fwd(k)

chunk_indices = (
prepare_chunk_indices(cu_seqlens, CHUNK_SIZE)
if cu_seqlens is not None
else None
)
g, o, A, w, h, v_new = chunk_gated_delta_rule_fwd(
q=q,
k=k,
Expand All @@ -107,6 +112,7 @@ def forward(
initial_state=initial_state,
initial_state_indices=initial_state_indices,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
return o.to(q.dtype), h

Expand Down
Loading
Loading