Merged

25 commits
96972ce
feat: add code completion engine and server
woodx9 Mar 18, 2025
f9ef0b6
delete useless space
woodx9 Mar 18, 2025
7780099
support bert model on native torch backend and triton backend
woodx9 Mar 28, 2025
6b7d0d6
rebase main
woodx9 Mar 29, 2025
773b98f
pre commit fix
woodx9 Mar 29, 2025
15264a6
add unit test for encoder only models
woodx9 Mar 29, 2025
935e0ca
Revert "delete useless space"
woodx9 Mar 29, 2025
a210cb0
Revert "feat: add code completion engine and server"
woodx9 Mar 29, 2025
fe115e4
rename is_causal to causal
woodx9 Mar 29, 2025
974eca0
fix typo
woodx9 Mar 29, 2025
409ab3d
remove useless log
woodx9 Mar 29, 2025
708f4ea
fix encoder model on flash attn
woodx9 Mar 30, 2025
ecfe402
refactor: simplify causal mask logic in Triton attention backend
woodx9 Mar 30, 2025
263503d
fix AttentionType as Enum
woodx9 Mar 30, 2025
9c0fd29
support facebook contriever model
woodx9 Mar 30, 2025
369f610
add speed test
woodx9 Mar 30, 2025
8525914
fix triton backend illegal memory access problem
woodx9 Mar 30, 2025
5f26068
add big batch size test on test encoder embedding models
woodx9 Mar 31, 2025
4534866
add triton and fa3 attention backend test
woodx9 Mar 31, 2025
68b731a
roll back fa3 as attn backend for encoder embedding models
woodx9 Apr 9, 2025
e2e458f
fix lint
woodx9 Apr 9, 2025
02476f0
roll back flash attn change
woodx9 Apr 9, 2025
37ab824
fix triton attention backend test
woodx9 Apr 9, 2025
452534d
Merge branch 'main' into feat/support_bert
zhyncs Apr 14, 2025
a390e44
fix lint error
woodx9 Apr 14, 2025
python/sglang/srt/layers/attention/torch_native_backend.py
@@ -6,6 +6,7 @@
from torch.nn.functional import scaled_dot_product_attention

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
@@ -202,6 +203,10 @@ def forward_extend(
q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

causal = True
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False

self._run_sdpa_forward_extend(
q_,
o_,
@@ -214,7 +219,7 @@ def forward_extend(
forward_batch.extend_seq_lens,
scaling=layer.scaling,
enable_gqa=use_gqa,
causal=not layer.is_cross_attention,
causal=causal,
)
return o

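The hunk above makes the torch-native SDPA backend drop the causal mask not only for cross-attention but also for encoder-only layers. A minimal, self-contained sketch of the behavioral difference, using PyTorch's public scaled_dot_product_attention directly rather than the backend's internal helper (shapes and values are illustrative only):

import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq_len, head_dim); values are illustrative only.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Decoder-style attention (causal=True): token i can only attend to tokens <= i.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Encoder-only / cross-attention path (causal=False): full bidirectional attention,
# which is what BERT-style embedding models require.
out_bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)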
6 changes: 6 additions & 0 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -10,6 +10,7 @@
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var, get_device_core_count

@@ -528,6 +529,10 @@ def forward_extend(
layer, forward_batch.out_cache_loc, k, v
)

causal = True
if layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False

self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
@@ -539,6 +544,7 @@
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.custom_mask,
causal,
self.forward_metadata.mask_indptr,
self.forward_metadata.max_extend_len,
layer.scaling,
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -74,6 +74,7 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
USE_CUSTOM_MASK: tl.constexpr,
IS_CAUSAL: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
):
@@ -129,6 +130,7 @@ def _fwd_kernel(
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix

offs_kv_loc = tl.load(
kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
)
@@ -196,7 +198,11 @@

# stage 2: compute the triangle part

cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
cur_block_m_end = (
cur_seq_len_extend
if not IS_CAUSAL
else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
)
for start_n in range(0, cur_block_m_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@
)
custom_mask &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(custom_mask, qk, float("-inf"))
else:
elif IS_CAUSAL:
mask_causal = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causal &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causal, qk, float("-inf"))
else:
mask_non_causal = mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_non_causal, qk, float("-inf"))

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -299,6 +308,7 @@ def extend_attention_fwd(
kv_indptr,
kv_indices,
custom_mask,
is_causal,
mask_indptr,
max_len_extend,
sm_scale=None,
@@ -411,6 +421,7 @@
Lq=Lq,
Lv=Lv,
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
IS_CAUSAL=is_causal,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
STORE_TRANSPOSE=_is_hip,
num_warps=num_warps,
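The new IS_CAUSAL constexpr changes two things in the extend kernel: the upper bound of the key-block loop (the full extend length instead of the end of the current query block) and the mask applied to the attention scores. A rough PyTorch sketch of the mask semantics, assuming a single extend segment of length seq_len (the Triton kernel works block-wise; this is only the logical equivalent):

import torch

def extend_attention_mask(seq_len: int, is_causal: bool) -> torch.Tensor:
    # Query positions i and key positions j within the extend segment.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    if is_causal:
        # Lower-triangular mask: query i attends only to keys j <= i.
        return i >= j
    # Non-causal: every query attends to every key
    # (the kernel still applies its block-bounds masks).
    return torch.ones(seq_len, seq_len, dtype=torch.bool)

# Masked-out positions are set to -inf before the softmax, mirroring
# tl.where(mask, qk, float("-inf")) in the kernel.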
15 changes: 15 additions & 0 deletions python/sglang/srt/layers/radix_attention.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
# ==============================================================================
"""Radix attention."""

from enum import Enum
from typing import Optional

from torch import nn
@@ -22,6 +23,18 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class AttentionType(Enum):
"""
Attention type.
String values are used so the enum stays compatible with `torch.compile`.
"""

# Causal decoder self-attention: each token attends only to earlier tokens.
DECODER = "decoder"
# Bidirectional encoder-only self-attention (BERT-style): no causal mask.
ENCODER_ONLY = "encoder_only"


class RadixAttention(nn.Module):
"""
The attention layer implementation.
@@ -39,6 +52,7 @@ def __init__(
sliding_window_size: int = -1,
is_cross_attention: bool = False,
quant_config: Optional[QuantizationConfig] = None,
attn_type: AttentionType = AttentionType.DECODER,
prefix: str = "",
use_irope: bool = False,
):
@@ -64,6 +78,7 @@
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
if self.quant_method is not None:
self.quant_method.create_weights(self)
self.attn_type = attn_type

def forward(
self,
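With attn_type exposed on RadixAttention, an encoder-only (BERT-style) model can request bidirectional attention when it constructs its attention layers. A hedged usage sketch: the keyword arguments other than attn_type follow RadixAttention's usual constructor parameters, and the numeric values are illustrative placeholders, not taken from this diff.

from sglang.srt.layers.radix_attention import AttentionType, RadixAttention

# Illustrative BERT-base-like layer configuration (values are placeholders).
self_attn = RadixAttention(
    num_heads=12,
    head_dim=64,
    scaling=64**-0.5,
    num_kv_heads=12,
    layer_id=0,
    attn_type=AttentionType.ENCODER_ONLY,  # bidirectional attention, no causal mask
)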