Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def forward_extend(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
sinks=None,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
Expand Down Expand Up @@ -731,7 +731,7 @@ def forward_extend(
layer.scaling,
layer.logit_cap,
sliding_window_size=sliding_window_size,
sk=sk,
sinks=sinks,
)
return o

Expand All @@ -743,7 +743,7 @@ def forward_decode(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
sinks=None,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
Expand Down Expand Up @@ -780,7 +780,7 @@ def forward_decode(
self.max_kv_splits,
layer.scaling,
layer.logit_cap,
sk=sk,
sinks=sinks,
)
return o

Expand Down
32 changes: 16 additions & 16 deletions python/sglang/srt/layers/attention/triton_ops/decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def _fwd_kernel_stage2(
O,
kv_indptr,
num_kv_splits,
sk_ptr,
sink_ptr,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
Expand All @@ -505,7 +505,7 @@ def _fwd_kernel_stage2(
MIN_BLOCK_KV: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
HAS_SK: tl.constexpr,
HAS_SINK: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
Expand Down Expand Up @@ -547,9 +547,9 @@ def _fwd_kernel_stage2(
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max

if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
e_sum += tl.exp(cur_sk - e_max)
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
e_sum += tl.exp(cur_sink - e_max)

tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
Expand All @@ -567,14 +567,14 @@ def _decode_softmax_reducev_fwd(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk=None,
sinks=None,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)

MAX_KV_SPLITS = max_kv_splits
HAS_SK = sk is not None
HAS_SINK = sinks is not None

extra_kargs = {}
if _is_hip:
Expand All @@ -589,7 +589,7 @@ def _decode_softmax_reducev_fwd(
o,
kv_indptr,
num_kv_splits,
sk,
sinks,
logits.stride(0),
logits.stride(1),
logits.stride(2),
Expand All @@ -599,7 +599,7 @@ def _decode_softmax_reducev_fwd(
MIN_BLOCK_KV=_MIN_BLOCK_KV,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
HAS_SK=HAS_SK,
HAS_SINK=HAS_SINK,
num_warps=4,
num_stages=2,
**extra_kargs,
Expand All @@ -619,7 +619,7 @@ def decode_attention_fwd_normal(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
sinks=None,
):
_decode_att_m_fwd(
q,
Expand All @@ -643,7 +643,7 @@ def decode_attention_fwd_normal(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
sinks,
)


Expand All @@ -660,7 +660,7 @@ def decode_attention_fwd_grouped(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
sinks=None,
):
_decode_grouped_att_m_fwd(
q,
Expand All @@ -684,7 +684,7 @@ def decode_attention_fwd_grouped(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
sinks,
)


Expand All @@ -701,7 +701,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
sinks=None,
):
assert max_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
Expand All @@ -724,7 +724,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
sinks=sinks,
)
else:
# GQA/MQA/MLA
Expand All @@ -741,5 +741,5 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
sinks=sinks,
)
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _fwd_kernel(
kv_indices,
mask_ptr,
mask_indptr,
sk_ptr,
sink_ptr,
sm_scale,
kv_group_num,
stride_qbs,
Expand Down Expand Up @@ -79,7 +79,7 @@ def _fwd_kernel(
IS_CAUSAL: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
HAS_SK: tl.constexpr,
HAS_SINK: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
Expand Down Expand Up @@ -302,9 +302,9 @@ def _fwd_kernel(

e_max = n_e_max

if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
deno += tl.exp(cur_sk - e_max)
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
deno += tl.exp(cur_sink - e_max)

offs_o = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
Expand Down Expand Up @@ -344,7 +344,7 @@ def extend_attention_fwd(
logit_cap=0.0,
skip_prefix_custom_mask=True,
sliding_window_size=-1,
sk=None,
sinks=None,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
Expand Down Expand Up @@ -410,7 +410,7 @@ def extend_attention_fwd(
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

HAS_SK = sk is not None
HAS_SINK = sinks is not None

grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1
Expand All @@ -431,7 +431,7 @@ def extend_attention_fwd(
kv_indices,
custom_mask,
mask_indptr,
sk,
sinks,
sm_scale,
kv_group_num,
q_extend.stride(0),
Expand All @@ -458,7 +458,7 @@ def extend_attention_fwd(
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
IS_CAUSAL=is_causal,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
HAS_SK=HAS_SK,
HAS_SINK=HAS_SINK,
STORE_TRANSPOSE=_is_hip,
num_warps=num_warps,
num_stages=num_stages,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def forward_core(self, intermediate_state):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state, sk=self.sinks)
attn_output = self.attn(*inner_state, sinks=self.sinks)
output, _ = self.o_proj(attn_output)
return output

Expand Down
Loading