diff --git a/python/sglang/srt/observability/req_time_stats.py b/python/sglang/srt/observability/req_time_stats.py index aa067152fd38..43f3356a8070 100644 --- a/python/sglang/srt/observability/req_time_stats.py +++ b/python/sglang/srt/observability/req_time_stats.py @@ -190,6 +190,21 @@ class RequestStage: metrics_is_observed=True, ) + # speculative decode + SPEC_DRAFT = RequestStageConfig( + "spec_draft", + level=2, + ) + + SPEC_VERIFY = RequestStageConfig( + "spec_verify", + level=2, + ) + + SPEC_DRAFT_EXTEND = RequestStageConfig( + "spec_draft_extend", + level=3, + ) # other ANONYMOUS = RequestStageConfig("") @@ -551,6 +566,11 @@ class SchedulerReqTimeStats(ReqTimeStatsBase): last_forward_entry_time: float = 0.0 last_prefill_finished_time: float = 0.0 + # speculative decoding + spec_draft_start_time: float = 0.0 + spec_verify_start_time: float = 0.0 + spec_draft_extend_start_time: float = 0.0 + # other transfer_speed_gb_s: float = 0.0 transfer_total_mb: float = 0.0 @@ -577,6 +597,42 @@ def set_scheduler_recv_time(self, ts=None): ts = ts or time.perf_counter() self.scheduler_recv_time = ts + def set_spec_draft_start_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.spec_draft_start_time = ts + + def set_spec_draft_end_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + + stage = RequestStage.SPEC_DRAFT + self.trace_slice(stage, self.spec_draft_start_time, ts) + + def set_spec_verify_start_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.spec_verify_start_time = ts + + def set_spec_verify_end_time(self, ts=None, accepted_tokens: int = 0): + if ts is None: + ts = time.perf_counter() + stage = RequestStage.SPEC_VERIFY + self.trace_slice( + stage, self.spec_verify_start_time, ts, {"accepted_tokens": accepted_tokens} + ) + + def set_spec_draft_extend_start_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.spec_draft_extend_start_time = ts + + def set_spec_draft_extend_end_time(self, 
ts=None): + if ts is None: + ts = time.perf_counter() + stage = RequestStage.SPEC_DRAFT_EXTEND + self.trace_slice(stage, self.spec_draft_extend_start_time, ts) + def set_retract_time(self, ts=None): ts = ts or time.perf_counter() # retract @@ -1049,9 +1105,11 @@ def set_schedule_time_batch(batch: ScheduleBatch): req.time_stats.set_last_scheduled_time(batch.forward_mode, ts, _attrs) -def set_time_batch(reqs: List[Any], set_func: str): +def set_time_batch(reqs: List[Any], set_func: str, trace_only: bool = False): if reqs is None or len(reqs) == 0: return + if trace_only and not get_global_tracing_enabled(): + return ts = time.perf_counter() for req in reqs: diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 59c63c17ca5e..333ed0038685 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -29,6 +29,8 @@ ForwardBatch, ForwardMode, ) +from sglang.srt.observability.req_time_stats import set_time_batch +from sglang.srt.observability.trace import get_global_tracing_enabled from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.draft_utils import DraftBackendFactory from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( @@ -312,14 +314,29 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul can_run_cuda_graph=can_run_cuda_graph, ) else: + set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True) + with self.draft_tp_context( self.draft_model_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): spec_info = self.draft(batch) + + set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) + set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) + logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( self.verify(batch, spec_info) ) + if get_global_tracing_enabled(): + for idx, req in 
enumerate(batch.reqs): + accepted = verify_output.accept_length_per_req_cpu[idx] + req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted) + + set_time_batch( + batch.reqs, "set_spec_draft_extend_start_time", trace_only=True + ) + with self.draft_tp_context( self.draft_model_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): @@ -332,6 +349,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul # decode is not finished self.forward_draft_extend_after_decode(batch) + set_time_batch( + batch.reqs, "set_spec_draft_extend_end_time", trace_only=True + ) + return GenerationBatchResult( logits_output=logits_output, next_token_ids=verify_output.verified_id, diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 4c0e79503170..d72eff2a15a8 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -10,6 +10,8 @@ from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.observability.req_time_stats import set_time_batch +from sglang.srt.observability.trace import get_global_tracing_enabled from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.cpp_ngram.ngram_corpus import NgramCorpus from sglang.srt.speculative.ngram_info import NgramVerifyInput @@ -250,7 +252,12 @@ def _update_ngram_corpus(self, batch: ScheduleBatch): self.ngram_corpus.batch_put(batch_tokens) def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult: + set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True) + self._prepare_for_speculative_decoding(batch) + + set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) + model_worker_batch = batch.get_model_worker_batch() spec_info = 
model_worker_batch.spec_info num_accepted_tokens = 0 @@ -265,6 +272,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul spec_info.retrive_next_token.shape ).cpu() + set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) + batch_result = self.target_worker.forward_batch_generation( model_worker_batch, is_verify=True ) @@ -298,6 +307,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul batch, logits_output, self.page_size, vocab_mask ) accept_length_per_req_cpu = verify_input.accept_length.cpu().tolist() + + if get_global_tracing_enabled(): + for idx, req in enumerate(batch.reqs): + accepted = ( + accept_length_per_req_cpu[idx] + if verify_input.accept_length is not None + else 0 + ) + req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted) + # Store accept_lens for per-request metrics accept_lens = verify_input.accept_length if batch.return_logprob: