16 changes: 10 additions & 6 deletions python/sglang/srt/constrained/xgrammar_backend.py
@@ -162,12 +162,16 @@ def __init__(
):
super().__init__()

# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
# This ensures consistency between what the model considers EOS and what XGrammar uses
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
)
override_stop_tokens = None
if hasattr(tokenizer, "init_xgrammar"):
# For special tokenizer
tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
else:
# Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
# This ensures consistency between what the model considers EOS and what XGrammar uses
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
)
override_stop_tokens = None

self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size
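The branch added here only requires that the tokenizer expose an init_xgrammar() method returning a (tokenizer_info, override_stop_tokens) pair. A minimal sketch of such a tokenizer, assuming it wraps a Hugging Face tokenizer; the class name and attributes are illustrative, not the actual TiktokenTokenizer:

from xgrammar import TokenizerInfo


class CustomXGrammarTokenizer:  # hypothetical example, not sglang's TiktokenTokenizer
    def __init__(self, hf_tokenizer, vocab_size, stop_token_ids):
        self.hf_tokenizer = hf_tokenizer      # underlying Hugging Face tokenizer
        self.vocab_size = vocab_size          # padded vocab size used by the model
        self.stop_token_ids = stop_token_ids  # stop tokens this tokenizer wants XGrammar to honor

    def init_xgrammar(self):
        # Build the TokenizerInfo ourselves and tell the backend which stop
        # tokens to use instead of the model's EOS ids.
        tokenizer_info = TokenizerInfo.from_huggingface(
            self.hf_tokenizer,
            vocab_size=self.vocab_size,
            stop_token_ids=self.stop_token_ids,
        )
        return tokenizer_info, self.stop_token_ids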
5 changes: 5 additions & 0 deletions python/sglang/srt/hf_transformers_utils.py
@@ -263,6 +263,11 @@ def get_tokenizer(
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
if tokenizer_name.endswith(".json"):
from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

return TiktokenTokenizer(tokenizer_name)

if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
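A short usage sketch of the new routing; the paths and model name below are illustrative:

from sglang.srt.hf_transformers_utils import get_tokenizer

# A tokenizer path ending in ".json" now short-circuits to TiktokenTokenizer;
# anything else still goes through the existing Hugging Face path.
tiktoken_tok = get_tokenizer("/models/custom/tokenizer.json")
hf_tok = get_tokenizer("meta-llama/Llama-3.1-8B-Instruct")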
18 changes: 16 additions & 2 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -20,6 +20,14 @@
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


def logit_capping_mod(logit_capping_method, logit_cap):
# positive logit_cap -> tanh cap
if logit_capping_method == "tanh":
return logit_cap
else:
raise ValueError()
Contributor review comment (medium):

Raising a ValueError without a descriptive message makes debugging difficult. It's better to include information about why the error is being raised, such as the unsupported method that was passed.

Suggested change:
- raise ValueError()
+ raise ValueError(f"Unsupported logit_capping_method: {logit_capping_method}")



@dataclass
class ForwardMetadata:
attn_logits: torch.Tensor
@@ -718,6 +726,8 @@ def forward_extend(
layer, forward_batch.out_cache_loc, k, v
)

logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)

causal = True
if layer.attn_type == AttentionType.ENCODER_ONLY:
causal = False
@@ -750,10 +760,11 @@
self.forward_metadata.mask_indptr,
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
logit_cap=logits_soft_cap,
sliding_window_size=sliding_window_size,
sinks=sinks,
window_kv_offsets=window_kv_offsets,
xai_temperature_len=layer.xai_temperature_len,
)
return o

@@ -777,6 +788,8 @@ def forward_decode(
else:
o = torch.empty_like(q)

logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)

if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, forward_batch.out_cache_loc, k, v
@@ -801,8 +814,9 @@
self.forward_metadata.num_kv_splits,
self.max_kv_splits,
layer.scaling,
layer.logit_cap,
logit_cap=logits_soft_cap,
sinks=sinks,
xai_temperature_len=layer.xai_temperature_len,
)
return o

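logit_capping_mod currently passes the tanh cap value straight through (and rejects any other method), and the kernels in the files below apply the cap as logit_cap * tanh(qk / logit_cap). A small eager-mode reference of that soft cap, useful for sanity-checking the Triton kernels; the function name is illustrative:

import torch


def soft_cap_reference(qk: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # Eager-mode equivalent of the tanh soft cap applied inside the kernels.
    if logit_cap > 0:
        return logit_cap * torch.tanh(qk / logit_cap)
    return qk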
@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
xai_temperature_len: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
@@ -85,6 +86,12 @@
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
kv_splits = tl.load(num_kv_splits + cur_batch)

if xai_temperature_len > 0:
offs_qidx = cur_batch_seq_len - 1
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)

off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

kv_len_per_split = (
@@ -122,6 +129,9 @@
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)

if xai_temperature_len > 0:
qk *= xai_temperature_reg

qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

offs_buf_v = (
@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len=-1,
):
BLOCK = 64
# [TODO] work around SGPR limit on MI3xx
@@ -230,6 +241,7 @@
BLOCK_N=BLOCK,
MIN_BLOCK_KV=_MIN_BLOCK_KV,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
num_warps=num_warps,
num_stages=2,
Lk=Lk,
@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
BLOCK_H: tl.constexpr,
MIN_BLOCK_KV: tl.constexpr,
logit_cap: tl.constexpr,
xai_temperature_len: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
@@ -291,6 +304,12 @@
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
kv_splits = tl.load(num_kv_splits + cur_batch)

if xai_temperature_len > 0:
offs_qidx = cur_batch_seq_len - 1
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
_qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)

offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]

if BLOCK_DPE > 0:
@@ -351,6 +370,9 @@
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)

if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]

qk = tl.where(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
)
@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len=-1,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
@@ -480,6 +503,7 @@
BLOCK_H=BLOCK_H,
MIN_BLOCK_KV=_MIN_BLOCK_KV,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
num_warps=4,
num_stages=num_stages,
Lk=Lk,
@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
sm_scale,
logit_cap=0.0,
sinks=None,
xai_temperature_len=-1,
):
_decode_att_m_fwd(
q,
@@ -633,6 +658,7 @@
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len,
)
_decode_softmax_reducev_fwd(
attn_logits,
@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
sm_scale,
logit_cap=0.0,
sinks=None,
xai_temperature_len=-1,
):
_decode_grouped_att_m_fwd(
q,
@@ -674,6 +701,7 @@
max_kv_splits,
sm_scale,
logit_cap,
xai_temperature_len,
)
_decode_softmax_reducev_fwd(
attn_logits,
@@ -702,6 +730,7 @@ def decode_attention_fwd(
sm_scale,
logit_cap=0.0,
sinks=None,
xai_temperature_len=-1,
):
assert max_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -725,6 +754,7 @@
sm_scale,
logit_cap=logit_cap,
sinks=sinks,
xai_temperature_len=xai_temperature_len,
)
else:
# GQA/MQA/MLA
@@ -742,4 +772,5 @@
sm_scale,
logit_cap=logit_cap,
sinks=sinks,
xai_temperature_len=xai_temperature_len,
)
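The new xai_temperature_len argument rescales attention scores for long contexts: for a query at 0-based position p past the threshold L, qk is multiplied by log2(p) / log2(L); positions at or below L are unchanged, and the default of -1 disables the branch entirely. In the decode kernels the query position is seq_len - 1. A small eager-mode sketch of the per-position factor, for reference only (the function name is illustrative; call it only with xai_temperature_len > 0, as the kernels do):

import torch


def xai_temperature_factor(position: torch.Tensor, xai_temperature_len: int) -> torch.Tensor:
    # position: 0-based query position(s); in the decode kernels this is seq_len - 1.
    scale = 1.0 / torch.log2(torch.tensor(float(xai_temperature_len)))
    qtemp = torch.log2(position.to(torch.float32)) * scale
    return torch.where(position > xai_temperature_len, qtemp, torch.ones_like(qtemp))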
@@ -69,6 +69,7 @@ def _fwd_kernel(
stride_buf_vh,
SLIDING_WINDOW_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
xai_temperature_len: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
@@ -109,6 +110,15 @@
mask_d = offs_d < Lq
mask_dv = offs_dv < Lv

if xai_temperature_len > 0:
offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
xai_temperature_reg = tl.where(
offs_qidx > xai_temperature_len,
tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
1.0,
)

offs_q = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
@@ -203,6 +213,9 @@
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)

if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]

qk = tl.where(final_mask, qk, float("-inf"))

row_max = tl.max(qk, 1)
@@ -306,6 +319,9 @@
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)

if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]

qk = tl.where(final_mask, qk, float("-inf"))

row_max = tl.max(qk, 1)
@@ -373,6 +389,7 @@ def extend_attention_fwd(
sliding_window_size=-1,
sinks=None,
window_kv_offsets=None,
xai_temperature_len=-1,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -477,6 +494,7 @@
v_buffer.stride(1),
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
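In the extend kernel the same factor is applied per query row, using each token's absolute position (prefix length plus its offset within the current tile) instead of seq_len - 1. A sketch of how one tile's positions and factors would look; the tile parameters are hypothetical:

import torch

# Hypothetical tile parameters, for illustration only.
prefix_len, block_start, BLOCK_M, L = 8160, 0, 64, 8192

pos = (prefix_len + block_start + torch.arange(BLOCK_M)).to(torch.float32)  # offs_qidx in the kernel
factor = torch.where(
    pos > L,
    torch.log2(pos) / torch.log2(torch.tensor(float(L))),
    torch.ones_like(pos),
)
# In the kernel this is xai_temperature_reg; it multiplies qk row-wise: qk *= factor[:, None]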
94 changes: 94 additions & 0 deletions python/sglang/srt/layers/elementwise.py
@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
return out_hidden_states, out_scales
else:
return out_hidden_states, None


# silu on first half of vector
@triton.jit
def silu_and_mul_kernel(
out_hidden_states_ptr, # (bs, hidden_dim)
out_scales_ptr, # (bs,)
hidden_states_ptr, # (bs, hidden_dim * 2)
quant_max: tl.constexpr,
static_scale: tl.constexpr,
hidden_dim: tl.constexpr, # the output hidden_dim
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)

input_start = pid * hidden_dim * 2
output_start = pid * hidden_dim

input1_offs = tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
output_offs = tl.arange(0, BLOCK_SIZE)

x1 = tl.load(
hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
).to(tl.float32)
x3 = tl.load(
hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
).to(tl.float32)

# silu
# cast down before mul to better match training?
silu_x1 = x1 * tl.sigmoid(x1)
out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)

if quant_max is not None:
raise NotImplementedError()

tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)


def silu_and_mul_triton(
hidden_states,
scales=None,
quantize=None, # dtype to quantize to
out=None,
):
bs, in_hidden_dim = hidden_states.shape
hidden_dim = in_hidden_dim // 2

if out is None:
out_hidden_states = torch.empty(
(bs, hidden_dim),
dtype=quantize or hidden_states.dtype,
device=hidden_states.device,
)
else:
assert out.shape == (bs, hidden_dim)
assert out.dtype == (quantize or hidden_states.dtype)
out_hidden_states = out
out_scales = None
static_scale = False
if quantize is not None:
if scales is None:
out_scales = torch.empty(
(bs,), dtype=torch.float32, device=hidden_states.device
)
else:
out_scales = scales
static_scale = True

max_warps = 16 if _is_hip else 32
config = {
# 8 ele per thread (not tuned)
"num_warps": max(
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
),
}

silu_and_mul_kernel[(bs,)](
out_hidden_states,
out_scales,
hidden_states,
quant_max=torch.finfo(quantize).max if quantize is not None else None,
static_scale=static_scale,
hidden_dim=hidden_dim,
BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
**config,
)

if quantize is not None:
return out_hidden_states, out_scales
else:
return out_hidden_states, None
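The kernel computes SiLU(x1) * x3, where x1 is the first half and x3 the second half of the last dimension, and casts the SiLU result back to the input dtype before the multiply; the quantization path (quantize=...) is still unimplemented and raises. A small eager-mode reference plus a usage sketch, with shapes and dtype as assumptions:

import torch


def silu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
    # Eager-mode counterpart of silu_and_mul_kernel (no quantization path).
    x1, x3 = hidden_states.float().chunk(2, dim=-1)
    # Mirror the kernel: SiLU is computed in fp32, cast to the input dtype
    # before the multiply, and the product is stored in the input dtype.
    silu_x1 = (x1 * torch.sigmoid(x1)).to(hidden_states.dtype).float()
    return (x3 * silu_x1).to(hidden_states.dtype)


# Usage sketch (illustrative shapes/dtype; the Triton path needs a GPU):
# x = torch.randn(8, 2 * 4096, dtype=torch.bfloat16, device="cuda")
# out_triton, _ = silu_and_mul_triton(x)
# torch.testing.assert_close(out_triton, silu_and_mul_reference(x))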