python/sglang/srt/layers/attention/triton_backend.py (203 changes: 198 additions & 5 deletions)
@@ -72,6 +72,65 @@ def get_num_kv_splits_triton(
tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)


def update_sliding_window_buffer(
window_kv_indptr,
req_to_token,
sliding_window_size,
seq_lens,
req_pool_indices,
bs,
device,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
window_kv_indices = torch.empty(
window_kv_indptr[-1], dtype=torch.int32, device=device
)
window_kv_start_idx = seq_lens - window_kv_lens
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
window_kv_lens,
window_kv_indptr,
window_kv_start_idx,
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_indices, window_kv_lens
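
A minimal worked example of what this helper computes, assuming a batch of three requests and a 512-token window (the values and CPU tensors are illustrative only; the real helper also fills window_kv_indices from req_to_token via the create_flashinfer_kv_indices_triton kernel):

import torch

sliding_window_size = 512
seq_lens = torch.tensor([2, 10, 700])  # hypothetical per-request KV lengths

# Each request keeps at most the last (window + 1) tokens of its KV cache.
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
# -> tensor([  2,  10, 513])

# Exclusive prefix sum: request i owns window_kv_indices[indptr[i]:indptr[i + 1]].
window_kv_indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int64)
window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)
# -> tensor([  0,   2,  12, 525])

# Offset of the first kept token inside each request's req_to_token row.
window_kv_start_idx = seq_lens - window_kv_lens
# -> tensor([  0,   0, 187])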


def update_sliding_window_buffer_cuda_graph(
window_kv_indptr,
window_kv_indices,
req_to_token,
sliding_window_size,
seq_lens,
req_pool_indices,
bs,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
window_kv_start_idx = seq_lens - window_kv_lens
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
window_kv_lens,
window_kv_indptr,
window_kv_start_idx,
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_lens


@dataclass
class ForwardMetadata:
attn_logits: torch.Tensor
@@ -83,6 +142,10 @@ class ForwardMetadata:
qo_indptr: torch.Tensor
custom_mask: torch.Tensor
mask_indptr: torch.Tensor
# Sliding window
window_kv_indptr: torch.Tensor
window_kv_indices: torch.Tensor
window_num_kv_splits: torch.Tensor


class TritonAttnBackend(AttentionBackend):
@@ -109,13 +172,32 @@ def __init__(

max_bs = model_runner.req_to_token_pool.size

assert not (
model_runner.sliding_window_size is not None
and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together"
self.sliding_window_size = model_runner.sliding_window_size

# TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf

# If sliding window is enabled, we might need two sets of buffers
# because of interleaved attention types (e.g. for Gemma3)
self.window_kv_indptr = None
if self.sliding_window_size is not None and self.sliding_window_size > 0:
if kv_indptr_buf is None:
self.window_kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
# When provided a buffer, create a clone for the second buffer
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

self.req_to_token = model_runner.req_to_token_pool.req_to_token

if not self.skip_prefill:
@@ -190,6 +272,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
window_kv_indptr = self.window_kv_indptr
window_kv_indices = None
window_num_kv_splits = None
spec_info = forward_batch.spec_info

if forward_batch.forward_mode.is_decode_or_idle():
@@ -208,6 +293,26 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices,
self.req_to_token.stride(0),
)
# Sliding window
if (
self.sliding_window_size is not None
and self.sliding_window_size > 0
):
window_kv_indptr, window_kv_indices, window_kv_lens = (
update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.seq_lens,
forward_batch.req_pool_indices,
bs,
self.device,
)
)
window_num_kv_splits = torch.empty(
(bs,), dtype=torch.int32, device=self.device
)
self.get_num_kv_splits(window_num_kv_splits, window_kv_lens)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
@@ -223,14 +328,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
device=self.device,
)
num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)

self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)

qo_indptr = None
custom_mask = None
mask_indptr = None
max_extend_len = None
elif forward_batch.forward_mode.is_target_verify():
# TODO: Support sliding window in spec inference
bs = len(forward_batch.req_pool_indices)
qo_indptr = torch.arange(
0,
@@ -302,6 +407,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices,
self.req_to_token.stride(0),
)
# Sliding window
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.extend_prefix_lens,
forward_batch.req_pool_indices,
bs,
self.device,
)

qo_indptr = self.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
@@ -323,6 +439,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
qo_indptr,
custom_mask,
mask_indptr,
window_kv_indptr,
window_kv_indices,
window_num_kv_splits,
)

def init_cuda_graph_state(
@@ -357,6 +476,20 @@ def init_cuda_graph_state(
device=self.device,
)

if self.sliding_window_size is not None and self.sliding_window_size > 0:
if kv_indices_buf is None:
self.cuda_graph_window_kv_indices = torch.zeros(
(max_bs * self.sliding_window_size),
dtype=torch.int32,
device=self.device,
)
else:
self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

self.cuda_graph_window_num_kv_splits = torch.full(
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
)
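
A rough, self-contained sketch of the capture/replay discipline these fixed-size buffers exist for (a generic illustration of the CUDA-graph constraint that captured tensors must keep their storage; the sizes and the refresh helper below are made up, and the real code fills the indices with a Triton kernel):

import torch

max_bs, sliding_window_size, max_kv_splits = 4, 512, 8

# Allocated once, before graph capture, sized for the worst case.
window_kv_indices = torch.zeros(max_bs * sliding_window_size, dtype=torch.int32)
window_num_kv_splits = torch.full((max_bs,), max_kv_splits, dtype=torch.int32)

def refresh(window_kv_lens: torch.Tensor) -> None:
    # Per-step update: rewrite contents in place, never reallocate, so the
    # pointers recorded in the captured graph stay valid across replays.
    total = int(window_kv_lens.sum())
    window_kv_indices[:total] = 0  # placeholder fill; illustration only
    window_num_kv_splits[: window_kv_lens.numel()] = max_kv_splits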

def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
@@ -368,6 +501,9 @@ def init_forward_metadata_capture_cuda_graph(
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
assert encoder_lens is None, "Not supported"
window_kv_indptr = self.window_kv_indptr
window_kv_indices = None
window_num_kv_splits = None

if forward_mode.is_decode_or_idle():
if spec_info is None:
@@ -384,6 +520,21 @@ def init_forward_metadata_capture_cuda_graph(
kv_indices,
self.req_to_token.stride(0),
)
if (
self.sliding_window_size is not None
and self.sliding_window_size > 0
):
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

@@ -439,6 +590,9 @@ def init_forward_metadata_capture_cuda_graph(
qo_indptr,
custom_mask,
mask_indptr,
window_kv_indptr,
window_kv_indices,
window_num_kv_splits,
)

def init_forward_metadata_replay_cuda_graph(
@@ -471,11 +625,31 @@ def init_forward_metadata_replay_cuda_graph(
self.req_to_token.stride(0),
)
num_token = bs
if (
self.sliding_window_size is not None
and self.sliding_window_size > 0
):
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices[:bs],
bs,
)
self.get_num_kv_splits(
window_num_kv_splits[:num_token], window_kv_lens[:bs]
)

else:
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
num_token = spec_info.kv_indptr.shape[0] - 1
self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])

elif forward_mode.is_target_verify():
# Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
bs = len(req_pool_indices)
@@ -536,6 +710,17 @@ def forward_extend(
if layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False

if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
sliding_window_size = (
layer.sliding_window_size
) # Needed for sliding window mask
kv_indptr = self.forward_metadata.window_kv_indptr
kv_indices = self.forward_metadata.window_kv_indices
else:
sliding_window_size = -1
kv_indptr = self.forward_metadata.kv_indptr
kv_indices = self.forward_metadata.kv_indices

self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
@@ -544,14 +729,15 @@ def forward_extend(
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
self.forward_metadata.qo_indptr,
- self.forward_metadata.kv_indptr,
- self.forward_metadata.kv_indices,
+ kv_indptr,
+ kv_indices,
self.forward_metadata.custom_mask,
causal,
self.forward_metadata.mask_indptr,
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
sliding_window_size,
)
return o
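
To make the per-layer buffer selection above concrete, here is a hypothetical sketch of how a model with interleaved attention types (e.g. Gemma3-style, alternating sliding-window and global layers) would route each layer to one of the two buffer sets; the Layer class and the layout are invented for illustration:

from dataclasses import dataclass

@dataclass
class Layer:
    layer_id: int
    sliding_window_size: int  # -1 means full (global) attention

# Hypothetical 6-layer pattern: five sliding-window layers, then one global layer.
layers = [Layer(i, 1024 if (i + 1) % 6 != 0 else -1) for i in range(6)]

for layer in layers:
    if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
        buffers = "window_kv_indptr / window_kv_indices"
    else:
        buffers = "kv_indptr / kv_indices"
    print(f"layer {layer.layer_id}: {buffers}")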

@@ -579,13 +765,20 @@ def forward_decode(
layer, forward_batch.out_cache_loc, k, v
)

if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
kv_indptr = self.forward_metadata.window_kv_indptr
kv_indices = self.forward_metadata.window_kv_indices
else:
kv_indptr = self.forward_metadata.kv_indptr
kv_indices = self.forward_metadata.kv_indices

self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
- self.forward_metadata.kv_indptr,
- self.forward_metadata.kv_indices,
+ kv_indptr,
+ kv_indices,
self.forward_metadata.attn_logits,
self.forward_metadata.attn_lse,
self.forward_metadata.num_kv_splits,
Second file in this diff: the Triton extend-attention kernel used by this backend.
@@ -65,6 +65,7 @@ def _fwd_kernel(
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
SLIDING_WINDOW_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
@@ -163,6 +164,7 @@ def _fwd_kernel(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)

final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
custom_mask = tl.load(
mask_ptr
@@ -173,10 +175,14 @@ def _fwd_kernel(
mask=(mask_m[:, None] & mask_n[None, :]),
other=0,
)
-     custom_mask &= mask_m[:, None] & mask_n[None, :]
-     qk = tl.where(custom_mask, qk, float("-inf"))
- else:
-     qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+     final_mask &= custom_mask
+ if SLIDING_WINDOW_SIZE > 0:
+     # Add mask where q_id <= kv_id + sliding_window_size
+     window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+         start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+     )
+     final_mask &= window_mask
+ qk = tl.where(final_mask, qk, float("-inf"))

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
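
The window mask added in this hunk has the banded form q_pos <= kv_pos + SLIDING_WINDOW_SIZE, combined with the other masks already folded into final_mask. A small plain-PyTorch illustration of that predicate on absolute token positions (independent of the kernel's block and offset bookkeeping, which is omitted here):

import torch

# Toy case: 6 query positions x 6 kv positions, window of 2. True means
# "query may attend to this kv"; combined with causality, each query keeps
# itself plus the previous `window` tokens.
window = 2
q_pos = torch.arange(6)[:, None]
kv_pos = torch.arange(6)[None, :]

window_mask = q_pos <= kv_pos + window   # equivalently: kv_pos >= q_pos - window
causal_mask = q_pos >= kv_pos
final_mask = window_mask & causal_mask
print(final_mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])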
@@ -314,6 +320,7 @@ def extend_attention_fwd(
sm_scale=None,
logit_cap=0.0,
skip_prefix_custom_mask=True,
sliding_window_size=-1,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -412,6 +419,7 @@ def extend_attention_fwd(
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,