Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ dev = ["sglang[test]"]
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
"srt/layers/quantization/configs/*.json",
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
"srt/speculative/cpp_lookahead/*.cpp",
"srt/speculative/cpp_lookahead/*.h",
]

[tool.setuptools.packages.find]
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/bench_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,8 @@ def sample_sharegpt_requests(
add_generation_prompt=True,
tokenize=False,
)
prompt = prompt.replace(tokenizer.bos_token, "")
if tokenizer.bos_token:
prompt = prompt.replace(tokenizer.bos_token, "")

prompt_token_ids = tokenizer.encode(prompt)
completion = dataset[i][1]
Expand Down
53 changes: 39 additions & 14 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.utils import (
is_flashinfer_available,
is_sm100_supported,
Expand Down Expand Up @@ -317,7 +318,9 @@ def init_forward_metadata_capture_cuda_graph(
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
Expand Down Expand Up @@ -422,7 +425,9 @@ def init_forward_metadata_replay_cuda_graph(
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
Expand Down Expand Up @@ -638,7 +643,9 @@ def update(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
Expand All @@ -651,7 +658,9 @@ def update_single_wrapper(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
Expand All @@ -673,7 +682,9 @@ def update_sliding_window(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
assert self.sliding_window_size is not None
for wrapper_id in range(2):
Expand Down Expand Up @@ -721,7 +732,9 @@ def update_cross_attention(
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
for wrapper_id in range(2):
if wrapper_id == 0:
Expand Down Expand Up @@ -753,7 +766,9 @@ def call_begin_forward(
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
):
Expand Down Expand Up @@ -858,7 +873,9 @@ def update(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
Expand All @@ -873,7 +890,9 @@ def update_single_wrapper(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
if use_ragged:
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
Expand Down Expand Up @@ -909,7 +928,9 @@ def update_sliding_window(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
for wrapper_id in range(2):
if wrapper_id == 0:
Expand Down Expand Up @@ -955,7 +976,9 @@ def update_cross_attention(
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
):
for wrapper_id in range(2):
if wrapper_id == 0:
Expand Down Expand Up @@ -997,7 +1020,9 @@ def call_begin_forward(
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
],
use_sliding_window_kv_pool: bool = False,
):
bs = len(seq_lens)
Expand All @@ -1024,8 +1049,8 @@ def call_begin_forward(
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
assert isinstance(spec_info, EagleDraftInput) or isinstance(
spec_info, EagleVerifyInput
assert isinstance(
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
Expand Down
15 changes: 12 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
Expand Down Expand Up @@ -950,7 +951,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
] = None

# Whether to return hidden states
return_hidden_states: bool = False
Expand Down Expand Up @@ -1600,7 +1603,11 @@ def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
bs = len(self.reqs)

if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
if (
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_lookahead()
):
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
return
Expand Down Expand Up @@ -1975,7 +1982,9 @@ class ModelWorkerBatch:

# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
spec_info: Optional[
Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1
Expand Down
25 changes: 18 additions & 7 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,18 @@ def __init__(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_lookahead():
from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker

self.draft_worker = LOOKAHEADWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None

Expand Down Expand Up @@ -738,8 +750,8 @@ def init_memory_pool_and_cache(self):
else (
server_args.speculative_num_draft_tokens
+ (
server_args.speculative_eagle_topk
* server_args.speculative_num_steps
(server_args.speculative_eagle_topk or 1)
* (server_args.speculative_num_steps or 1)
)
)
)
Expand Down Expand Up @@ -782,7 +794,7 @@ def init_disaggregation(self):
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
draft_token_to_kv_pool=(
None
if self.draft_worker is None
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
Expand Down Expand Up @@ -819,7 +831,7 @@ def init_disaggregation(self):
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=(
None
if self.draft_worker is None
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
Expand Down Expand Up @@ -2356,9 +2368,8 @@ def flush_cache(self):
self.req_to_token_pool.clear()
self.token_to_kv_pool_allocator.clear()

if not self.spec_algorithm.is_none():
self.draft_worker.model_runner.req_to_token_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
if self.draft_worker:
self.draft_worker.clear_cache_pool()

self.num_generated_tokens = 0
self.forward_ct_decode = 0
Expand Down
11 changes: 11 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
trace_get_proc_propagate_context,
trace_req_finish,
Expand Down Expand Up @@ -174,6 +175,15 @@ def __init__(
self.image_token_id = self.model_config.image_token_id
self.max_req_input_len = None # Will be set later in engine.py

speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.reserve_input_token_num = (
0
if speculative_algorithm.is_none()
else server_args.speculative_num_draft_tokens
)

if self.model_config.is_multimodal:
import_processors()
try:
Expand Down Expand Up @@ -618,6 +628,7 @@ def _validate_one_request(
_max_req_len = self.context_len

input_token_num = len(input_ids) if input_ids is not None else 0
input_token_num += self.reserve_input_token_num
if input_token_num >= self.context_len:
if self.server_args.allow_auto_truncate:
logger.warning(
Expand Down
25 changes: 25 additions & 0 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(self, model_runner: ModelRunner):
if (
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_lookahead()
):
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
Expand Down Expand Up @@ -440,11 +441,21 @@ def can_run(self, forward_batch: ForwardBatch):
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
)

is_lookahead_supported = (
(
forward_batch.batch_size * self.num_tokens_per_bs
== forward_batch.input_ids.numel()
)
if self.model_runner.spec_algorithm.is_lookahead()
else True
)

return (
is_bs_supported
and is_encoder_lens_supported
and is_tbo_supported
and capture_hidden_mode_matches
and is_lookahead_supported
)

def capture(self) -> None:
Expand Down Expand Up @@ -855,6 +866,20 @@ def get_spec_info(self, num_tokens: int):
seq_lens_cpu=None,
)

elif self.model_runner.spec_algorithm.is_lookahead():
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput

spec_info = LookaheadVerifyInput(
draft_token=None,
tree_mask=self.custom_mask,
positions=None,
retrive_index=None,
retrive_next_token=None,
retrive_next_sibling=None,
draft_token_num=self.num_tokens_per_bs,
)
spec_info.capture_hidden_mode = CaptureHiddenMode.NULL

return spec_info


Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ def init_memory_pool(
if self.is_hybrid_gdn:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

if not self.spec_algorithm.is_none():
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
max_num_reqs = self.server_args.max_num_reqs
Expand Down
Loading
Loading