2 changes: 1 addition & 1 deletion python/sglang/srt/constrained/base_grammar_backend.py
@@ -41,7 +41,7 @@ def rollback(self, k: int):
raise NotImplementedError()

def is_terminated(self):
raise NotImplementedError()
return False

def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
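The default for `is_terminated()` changes from raising `NotImplementedError` to returning `False`, which pairs with the new grammar-termination branch added to `check_finished()` in `schedule_batch.py` below. A minimal sketch of the intended contract, assuming a hypothetical backend grammar (`ToyJsonGrammar` and its brace-counting logic are illustrative only, not part of this PR):

```python
class ToyJsonGrammar:
    """Hypothetical grammar object; only is_terminated() matters here."""

    def __init__(self):
        self.depth = 0        # unmatched '{' seen so far
        self.started = False  # whether the opening '{' has been consumed

    def accept_token(self, token_text: str) -> None:
        # Toy state update: count braces to detect a balanced JSON object.
        for ch in token_text:
            if ch == "{":
                self.depth += 1
                self.started = True
            elif ch == "}":
                self.depth -= 1

    def is_terminated(self) -> bool:
        # True once the object is closed; backends that never self-terminate
        # can simply inherit the new `return False` default.
        return self.started and self.depth == 0
```

Backends that do not track termination keep their previous behavior, while backends that do can now end a request as soon as the constrained output is complete.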
29 changes: 21 additions & 8 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -28,6 +28,7 @@
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchMultimodalDecodeReq,
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
)
@@ -60,6 +61,8 @@ class DecodeStatus:
decode_ids: List[int]
surr_offset: int
read_offset: int
# Offset that's sent to tokenizer for incremental update.
sent_offset: int = 0


class DetokenizerManager:
@@ -151,7 +154,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
self.decode_status[rid] = s
else:
s = self.decode_status[rid]
s.decode_ids = recv_obj.decode_ids[i]
s.decode_ids.extend(recv_obj.decode_ids[i])

read_ids.append(
self.trim_matched_stop(
@@ -199,13 +202,15 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
else:
new_text = find_printable_text(new_text)

output_strs.append(
self.trim_matched_stop(
s.decoded_text + new_text,
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
output_str = self.trim_matched_stop(
s.decoded_text + new_text,
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
# Incrementally send text.
incremental_output = output_str[s.sent_offset :]
s.sent_offset = len(output_str)
output_strs.append(incremental_output)

return BatchStrOut(
rids=recv_obj.rids,
@@ -232,7 +237,15 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
)

def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
raise NotImplementedError()
outputs = self.tokenizer.detokenize(recv_obj)
return BatchMultimodalOut(
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
outputs=outputs,
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
)


class LimitedCapacityDict(OrderedDict):
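The substantive change in this file is the `sent_offset` bookkeeping: the detokenizer still maintains the full decoded string per request, but each `BatchStrOut` now carries only the suffix that has not been sent yet. A self-contained sketch of that slicing logic (the `IncrementalSender` name is illustrative, not part of the codebase):

```python
from dataclasses import dataclass


@dataclass
class IncrementalSender:
    """Tracks how much of a request's output text has already been sent."""

    sent_offset: int = 0

    def next_chunk(self, full_text: str) -> str:
        # Ship only the newly produced suffix, then advance the offset.
        chunk = full_text[self.sent_offset:]
        self.sent_offset = len(full_text)
        return chunk


# The receiver reconstructs the full text by concatenating the chunks.
sender = IncrementalSender()
assert sender.next_chunk("Hel") == "Hel"
assert sender.next_chunk("Hello, wo") == "lo, wo"
assert sender.next_chunk("Hello, world") == "rld"
```

The practical effect is that per-step payloads scale with the newly decoded text rather than with the whole accumulated output.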
43 changes: 37 additions & 6 deletions python/sglang/srt/managers/schedule_batch.py
@@ -52,6 +52,7 @@
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -436,6 +437,7 @@ def __init__(
self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
self.lora_path = lora_path

# Memory pool info
self.req_pool_idx: Optional[int] = None
@@ -487,6 +489,13 @@ def __init__(
# For retraction
self.is_retracted = False

# Incremental streaming
self.send_token_offset: int = 0
self.send_decode_id_offset: int = 0
# TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
# because the decode server does not have the first output token logprobs
self.send_output_token_logprobs_offset: int = 0

# Logprobs (arguments)
self.return_logprob = return_logprob
# Start index to compute logprob from.
@@ -496,11 +505,9 @@ def __init__(
self.temp_scaled_logprobs = False
self.top_p_normalized_logprobs = False

# Latency Breakdown
self.queue_time_start = None
self.queue_time_end = None

# Logprobs (return values)
# True means the input logprob has been already sent to detokenizer.
self.input_logprob_sent: bool = False
self.input_token_logprobs_val: Optional[List[float]] = None
self.input_token_logprobs_idx: Optional[List[int]] = None
self.input_top_logprobs_val: Optional[List[float]] = None
@@ -515,8 +522,10 @@ def __init__(
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

if return_logprob:
# shape: (bs, 1)
self.output_token_logprobs_val = []
self.output_token_logprobs_idx = []
# shape: (bs, k)
self.output_top_logprobs_val = []
self.output_top_logprobs_idx = []
self.output_token_ids_logprobs_val = []
@@ -543,7 +552,12 @@ def __init__(
# The number of verification forward passes in the speculative decoding.
# This is used to compute the average acceptance length per request.
self.spec_verify_ct = 0
self.lora_path = lora_path

# For metrics
self.time_stats: TimeStats = TimeStats()
self.has_log_time_stats: bool = False
self.queue_time_start = None
self.queue_time_end = None

# For disaggregation
self.bootstrap_host: str = bootstrap_host
@@ -562,8 +576,8 @@ def __init__(
# This is because kv is not ready in `process_prefill_chunk`.
# We use `tmp_end_idx` to store the end index of the kv cache to send.
self.tmp_end_idx: int = -1

self.metadata_buffer_index: int = -1

# The first output_id transferred from prefill instance.
self.transferred_output_id: Optional[int] = None

@@ -656,6 +670,11 @@ def check_finished(self):
)
return

if self.grammar is not None:
if self.grammar.is_terminated():
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
return

last_token_id = self.output_ids[-1]

if not self.sampling_params.ignore_eos:
@@ -713,6 +732,18 @@ def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
del self.kv_cache_cpu

def log_time_stats(self):
# If overlap schedule, we schedule one decode batch ahead so this gets called twice.
if self.has_log_time_stats is True:
return

if self.bootstrap_room is not None:
prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
else:
prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
logger.info(f"{prefix}: {self.time_stats}")
self.has_log_time_stats = True

def __repr__(self):
return (
f"Req(rid={self.rid}, "
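The new `send_token_offset`, `send_decode_id_offset`, and `send_output_token_logprobs_offset` fields apply the same idea on the `Req` side: each records how much of a per-request list has already been forwarded, so the next outgoing batch carries only the unsent tail. A hedged sketch of how such an offset is typically consumed (the `emit_unsent` helper is illustrative, not the actual scheduler code):

```python
from typing import List, Tuple


def emit_unsent(values: List[int], offset: int) -> Tuple[List[int], int]:
    """Return the not-yet-sent suffix of `values` and the advanced offset."""
    return values[offset:], len(values)


# Streaming output token ids for one request across two scheduler steps.
output_ids = [11, 42]
send_token_offset = 0

chunk, send_token_offset = emit_unsent(output_ids, send_token_offset)
assert chunk == [11, 42] and send_token_offset == 2

output_ids.extend([7, 99])
chunk, send_token_offset = emit_unsent(output_ids, send_token_offset)
assert chunk == [7, 99] and send_token_offset == 4
```

The TODO in the diff explains why the logprob offset is kept separate from the decode-id offset: under prefill/decode disaggregation the decode server never sees the first output token's logprob, so the two offsets can diverge.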
23 changes: 9 additions & 14 deletions python/sglang/srt/managers/scheduler.py
@@ -530,10 +530,6 @@ def init_memory_pool_and_cache(self):
)

def init_metrics(self):
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -1122,9 +1118,6 @@ def log_prefill_stats(
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
self._largest_prefill_len = max(
self._largest_prefill_len, adder.log_input_tokens
)

num_new_seq = len(can_run_list)
f = (
@@ -1601,14 +1594,9 @@ def process_batch_result(
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done)
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.set_next_batch_sampling_info_done(batch)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.set_next_batch_sampling_info_done(batch)

if self.return_health_check_ct:
# Return some signal for the health check.
@@ -1776,6 +1764,13 @@ def move_ready_grammar_requests(self):
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]

def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
if batch.next_batch_sampling_info:
if batch.next_batch_sampling_info.grammars is not None:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0