Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion python/sglang/srt/observability/req_time_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,21 @@ class RequestStage:
metrics_is_observed=True,
)

# speculative decode
# Stage configs for tracing the three phases of a speculative-decoding step:
# draft (propose tokens), verify (target model checks them), and draft-extend
# (run the draft model on accepted tokens). The `level` argument presumably
# controls trace verbosity/nesting depth — TODO confirm against the
# RequestStageConfig definition.
SPEC_DRAFT = RequestStageConfig(
"spec_draft",
level=2,
)

SPEC_VERIFY = RequestStageConfig(
"spec_verify",
level=2,
)

# NOTE(review): draft-extend is level=3 while draft/verify are level=2 —
# looks intentional (finer-grained stage), but worth confirming.
SPEC_DRAFT_EXTEND = RequestStageConfig(
"spec_draft_extend",
level=3,
)
# other
# Fallback stage with an empty name for slices that have no named stage.
ANONYMOUS = RequestStageConfig("")

Expand Down Expand Up @@ -548,6 +563,11 @@ class SchedulerReqTimeStats(ReqTimeStatsBase):
last_forward_entry_time: float = 0.0
last_prefill_finished_time: float = 0.0

# speculative decoding
spec_draft_start_time: float = 0.0
spec_verify_start_time: float = 0.0
spec_draft_extend_start_time: float = 0.0

# other
transfer_speed_gb_s: float = 0.0
transfer_total_mb: float = 0.0
Expand Down Expand Up @@ -575,6 +595,42 @@ def set_scheduler_recv_time(self, ts=None):
ts = time.perf_counter()
self.scheduler_recv_time = ts

def set_spec_draft_start_time(self, ts=None):
    """Record the timestamp at which the speculative draft phase began.

    Args:
        ts: A ``time.perf_counter()`` timestamp; when omitted, the
            current time is captured.
    """
    self.spec_draft_start_time = time.perf_counter() if ts is None else ts

def set_spec_draft_end_time(self, ts=None):
    """Close out the speculative draft phase and emit its trace slice.

    The slice spans from the previously recorded
    ``spec_draft_start_time`` to *ts* (or to now when *ts* is omitted).
    """
    end = ts if ts is not None else time.perf_counter()
    self.trace_slice(RequestStage.SPEC_DRAFT, self.spec_draft_start_time, end)

def set_spec_verify_start_time(self, ts=None):
    """Record the timestamp at which the speculative verify phase began.

    Args:
        ts: A ``time.perf_counter()`` timestamp; when omitted, the
            current time is captured.
    """
    self.spec_verify_start_time = time.perf_counter() if ts is None else ts

def set_spec_verify_end_time(self, ts=None, accepted_tokens: int = 0):
    """Close out the speculative verify phase and emit its trace slice.

    Args:
        ts: End timestamp; when omitted, the current time is captured.
        accepted_tokens: Number of draft tokens the target model accepted;
            attached to the slice as a trace attribute.
    """
    end = ts if ts is not None else time.perf_counter()
    attrs = {"accepted_tokens": accepted_tokens}
    self.trace_slice(RequestStage.SPEC_VERIFY, self.spec_verify_start_time, end, attrs)

def set_spec_draft_extend_start_time(self, ts=None):
    """Record the timestamp at which the draft-extend phase began.

    Args:
        ts: A ``time.perf_counter()`` timestamp; when omitted, the
            current time is captured.
    """
    self.spec_draft_extend_start_time = time.perf_counter() if ts is None else ts

def set_spec_draft_extend_end_time(self, ts=None):
    """Close out the draft-extend phase and emit its trace slice.

    The slice spans from the previously recorded
    ``spec_draft_extend_start_time`` to *ts* (or to now when omitted).
    """
    end = ts if ts is not None else time.perf_counter()
    self.trace_slice(
        RequestStage.SPEC_DRAFT_EXTEND, self.spec_draft_extend_start_time, end
    )

def set_retract_time(self, ts=None):
if ts is None:
ts = time.perf_counter()
Expand Down Expand Up @@ -1063,9 +1119,11 @@ def set_schedule_time_batch(batch: ScheduleBatch):
req.time_stats.set_last_scheduled_time(batch.forward_mode, ts, _attrs)


def set_time_batch(reqs: List[Any], set_func: str):
def set_time_batch(reqs: List[Any], set_func: str, trace_only: bool = False):
if reqs is None or len(reqs) == 0:
return
if trace_only and not get_global_tracing_enabled():
return

ts = time.perf_counter()
for req in reqs:
Expand Down
21 changes: 21 additions & 0 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
ForwardBatch,
ForwardMode,
)
from sglang.srt.observability.req_time_stats import set_time_batch
from sglang.srt.observability.trace import get_global_tracing_enabled
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.draft_utils import DraftBackendFactory
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
Expand Down Expand Up @@ -312,14 +314,29 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
can_run_cuda_graph=can_run_cuda_graph,
)
else:
set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True)

with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
spec_info = self.draft(batch)

set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)

logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)

if get_global_tracing_enabled():
for idx, req in enumerate(batch.reqs):
accepted = verify_output.accept_length_per_req_cpu[idx]
req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted)

set_time_batch(
batch.reqs, "set_spec_draft_extend_start_time", trace_only=True
)

with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
Expand All @@ -332,6 +349,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
# decode is not finished
self.forward_draft_extend_after_decode(batch)

set_time_batch(
batch.reqs, "set_spec_draft_extend_end_time", trace_only=True
)

return GenerationBatchResult(
logits_output=logits_output,
next_token_ids=verify_output.verified_id,
Expand Down
19 changes: 19 additions & 0 deletions python/sglang/srt/speculative/ngram_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.observability.req_time_stats import set_time_batch
from sglang.srt.observability.trace import get_global_tracing_enabled
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.cpp_ngram.ngram_corpus import NgramCorpus
from sglang.srt.speculative.ngram_info import NgramVerifyInput
Expand Down Expand Up @@ -215,7 +217,12 @@ def _update_ngram_corpus(self, batch: ScheduleBatch):
self.ngram_corpus.batch_put(batch_tokens)

def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True)

self._prepare_for_speculative_decoding(batch)

set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)

model_worker_batch = batch.get_model_worker_batch()
spec_info = model_worker_batch.spec_info
num_accepted_tokens = 0
Expand All @@ -229,6 +236,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
spec_info.retrive_next_token.shape
).cpu()

set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)

batch_result = self.target_worker.forward_batch_generation(
model_worker_batch, is_verify=True
)
Expand Down Expand Up @@ -261,6 +270,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
batch, logits_output, self.page_size, vocab_mask
)

if get_global_tracing_enabled():
for idx, req in enumerate(batch.reqs):
accepted = (
verify_input.accept_length[idx].item()
if verify_input.accept_length is not None
else 0
)
req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted)

# Store accept_lens for per-request metrics
accept_lens = verify_input.accept_length
if batch.return_logprob:
Expand Down
Loading