Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 85 additions & 14 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

if not self.skip_prefill:
self.qo_indptr = torch.zeros(
Expand Down Expand Up @@ -197,6 +198,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
)
window_num_kv_splits = torch.empty(
Expand Down Expand Up @@ -225,7 +227,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
mask_indptr = None
max_extend_len = None
elif forward_batch.forward_mode.is_target_verify():
# TODO: Support sliding window in spec inference
bs = len(forward_batch.req_pool_indices)
qo_indptr = torch.arange(
0,
Expand All @@ -250,6 +251,20 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
self.req_to_token.stride(0),
)

if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indptr, window_kv_indices, window_kv_lens = (
update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.seq_lens,
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
)

custom_mask = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (
forward_batch.seq_lens + self.num_draft_tokens
Expand Down Expand Up @@ -308,6 +323,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)

qo_indptr = self.qo_indptr
Expand Down Expand Up @@ -423,14 +439,17 @@ def init_forward_metadata_capture_cuda_graph(
):
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
window_kv_indptr, window_kv_indices, _ = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
Expand Down Expand Up @@ -464,6 +483,22 @@ def init_forward_metadata_capture_cuda_graph(
self.req_to_token.stride(0),
)

if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, window_kv_indices, _ = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
)

custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
Expand Down Expand Up @@ -557,14 +592,15 @@ def init_forward_metadata_replay_cuda_graph(
):
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices[:bs],
bs,
self.token_to_kv_pool_allocator,
)
self.get_num_kv_splits(
window_num_kv_splits[:num_token], window_kv_lens[:bs]
Expand Down Expand Up @@ -599,6 +635,19 @@ def init_forward_metadata_replay_cuda_graph(
kv_indices,
self.req_to_token.stride(0),
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
Expand Down Expand Up @@ -637,6 +686,7 @@ def forward_extend(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
Expand Down Expand Up @@ -680,7 +730,8 @@ def forward_extend(
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
sliding_window_size,
sliding_window_size=sliding_window_size,
sk=sk,
)
return o

Expand All @@ -692,6 +743,7 @@ def forward_decode(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
Expand Down Expand Up @@ -728,6 +780,7 @@ def forward_decode(
self.max_kv_splits,
layer.scaling,
layer.logit_cap,
sk=sk,
)
return o

Expand Down Expand Up @@ -932,10 +985,11 @@ def update_sliding_window_buffer(
req_pool_indices,
bs,
device,
token_to_kv_pool_allocator=None,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
torch.tensor(sliding_window_size),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
Expand All @@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
window_kv_indices,
req_to_token.stride(0),
)
# full to swa index mapping
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
kv_last_index = window_kv_indptr[-1]
window_kv_indices[:kv_last_index] = (
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens


Expand All @@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
seq_lens,
req_pool_indices,
bs,
token_to_kv_pool_allocator=None,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
torch.tensor(sliding_window_size),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
Expand All @@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_lens
# full to swa index mapping
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
kv_last_index = window_kv_indptr[-1]
window_kv_indices[:kv_last_index] = (
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
O,
kv_indptr,
num_kv_splits,
sk_ptr,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
Expand All @@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
MIN_BLOCK_KV: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
HAS_SK: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
Expand Down Expand Up @@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max

if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
e_sum += tl.exp(cur_sk - e_max)

tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
Expand All @@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk=None,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)

MAX_KV_SPLITS = max_kv_splits
HAS_SK = sk is not None

extra_kargs = {}
if _is_hip:
Expand All @@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
o,
kv_indptr,
num_kv_splits,
sk,
logits.stride(0),
logits.stride(1),
logits.stride(2),
Expand All @@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
MIN_BLOCK_KV=_MIN_BLOCK_KV,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
HAS_SK=HAS_SK,
num_warps=4,
num_stages=2,
**extra_kargs,
Expand All @@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
_decode_att_m_fwd(
q,
Expand All @@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
)


Expand All @@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
_decode_grouped_att_m_fwd(
q,
Expand All @@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
)


Expand All @@ -687,6 +701,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
assert max_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
Expand All @@ -709,6 +724,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
)
else:
# GQA/MQA/MLA
Expand All @@ -725,4 +741,5 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
)
Loading
Loading