Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
prepare_abort,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
Expand Down Expand Up @@ -253,6 +253,7 @@ def add(self, req: Req, is_retracted: bool = False) -> None:
prefill_dp_rank=req.data_parallel_rank,
)

req.add_latency(RequestStage.DECODE_PREPARE)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the prepare stage of the request, would it be better to use create_time as last_tic? This way we can include the tokenizer time in the metrics.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To get fine-grained statistics on the time spent in the tokenizer, prefill, and detokenizer stages, I think it's better to keep these stages separate; the metrics added in this PR only cover the latencies of the prefill and decode stages.

self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
)
Expand Down Expand Up @@ -421,6 +422,7 @@ def pop_preallocated(self) -> List[DecodeRequest]:
kv_indices, self.token_to_kv_pool_allocator.page_size
)
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)

Expand Down Expand Up @@ -662,6 +664,7 @@ def pop_transferred(self) -> List[Req]:
for i in indices_to_remove:
idx = self.queue[i].metadata_buffer_index
assert idx != -1
self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
self.req_to_metadata_buffer_idx_allocator.free(idx)

self.queue = [
Expand Down Expand Up @@ -853,6 +856,7 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
# we can only add at least `num_not_used_batch` new batch to the running queue
if i < num_not_used_batch:
can_run_list.append(req)
req.add_latency(RequestStage.DECODE_WAITING)
req.init_next_round_input(self.tree_cache)
else:
waiting_queue.append(req)
Expand Down
12 changes: 11 additions & 1 deletion python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
FINISH_LENGTH,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import (
DynamicGradMode,
Expand Down Expand Up @@ -170,6 +175,7 @@ def add(self, req: Req, num_kv_heads: int) -> None:
pp_rank=self.pp_rank,
)
self._process_req(req)
req.add_latency(RequestStage.PREFILL_PREPARE)
self.queue.append(req)

def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
Expand Down Expand Up @@ -256,6 +262,8 @@ def pop_bootstrapped(

num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)

req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
bootstrapped_reqs.append(req)
indices_to_remove.add(i)

Expand Down Expand Up @@ -404,6 +412,7 @@ def process_batch_result_disagg_prefill(
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD)
self.disagg_prefill_inflight_queue.append(req)
if (
logits_output is not None
Expand Down Expand Up @@ -539,6 +548,7 @@ def process_disagg_prefill_inflight_queue(
)
for req in done_reqs:
req: Req
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
req.metadata_buffer_index = -1

Expand Down
35 changes: 34 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import enum

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this below copyright

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +37,7 @@
import dataclasses
import logging
import threading
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
Expand All @@ -61,7 +64,7 @@
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
Expand Down Expand Up @@ -408,6 +411,23 @@ def merge(self, other: MultimodalInputs):
# other args would be kept intact


class RequestStage(str, Enum):
    """Lifecycle stages of a request, used as the ``stage`` label when
    recording per-stage latency histograms.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values. Uses the ``Enum`` name already imported at
    the top of this file instead of ``enum.Enum`` for consistency with the
    file's other enums.
    """

    # aggregated (non-disaggregation) prefill
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"


class Req:
"""The input and output status of a request."""

Expand All @@ -434,6 +454,7 @@ def __init__(
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
# Input and output info
self.rid = rid
Expand Down Expand Up @@ -591,10 +612,12 @@ def __init__(
self.spec_verify_ct = 0

# For metrics
self.metrics_collector = metrics_collector
self.time_stats: TimeStats = TimeStats()
self.has_log_time_stats: bool = False
self.queue_time_start = None
self.queue_time_end = None
self.last_tic = time.monotonic()

# For disaggregation
self.bootstrap_host: str = bootstrap_host
Expand Down Expand Up @@ -627,6 +650,16 @@ def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed)."""
return self.sampling_params.max_new_tokens == 0

def add_latency(self, stage: "RequestStage") -> None:
    """Record the time elapsed since the previous stage transition.

    Observes ``now - self.last_tic`` on the scheduler metrics collector
    under the given stage's label, then resets ``self.last_tic`` so the
    next call measures the following stage. No-op when metrics collection
    is disabled (``self.metrics_collector is None``).

    The original ``assert stage.name in RequestStage.__members__`` was
    removed: it is vacuously true for any ``RequestStage`` instance (the
    annotated type) and is stripped under ``python -O`` anyway, so it
    provided no real validation.
    """
    if self.metrics_collector is None:
        return
    now = time.monotonic()
    self.metrics_collector.observe_request_latency_seconds(
        stage.value, now - self.last_tic
    )
    self.last_tic = now

def extend_image_inputs(self, image_inputs):
if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
FINISH_ABORT,
MultimodalInputs,
Req,
RequestStage,
ScheduleBatch,
global_server_args_dict,
)
Expand Down Expand Up @@ -1196,6 +1197,9 @@ def handle_generate_request(
bootstrap_room=recv_req.bootstrap_room,
data_parallel_rank=recv_req.data_parallel_rank,
vocab_size=self.model_config.vocab_size,
metrics_collector=(
self.metrics_collector if self.enable_metrics else None
),
)
req.tokenizer = self.tokenizer

Expand Down Expand Up @@ -1734,6 +1738,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# only record queue time when enable_metrics is True to avoid overhead
for req in can_run_list:
req.queue_time_end = time.perf_counter()
req.add_latency(RequestStage.PREFILL_WAITING)

self.waiting_queue = [
x for x in self.waiting_queue if x not in set(can_run_list)
Expand Down
14 changes: 13 additions & 1 deletion python/sglang/srt/metrics/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from enum import Enum
from typing import Dict, List, Optional, Union

from sglang.srt.metrics.utils import generate_buckets
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var

Expand Down Expand Up @@ -514,6 +514,14 @@ def __init__(self, labels: Dict[str, str]) -> None:
buckets=tree_traversal_time_buckets,
)

self.request_latency_seconds = Histogram(
name="sglang:request_latency_seconds",
documentation="The latency of each stage of requests.",
# captures latency in range [1ms - ~1191s]
buckets=exponential_buckets(start=0.001, width=1.62, length=30),
labelnames=list(labels.keys()) + ["stage"],
)

def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)
Expand All @@ -527,6 +535,10 @@ def increment_bootstrap_failed_reqs(self) -> None:
def increment_transfer_failed_reqs(self) -> None:
self.num_transfer_failed_reqs.labels(**self.labels).inc(1)

def observe_request_latency_seconds(self, stage: str, latency: float) -> None:
    """Record one latency sample for the given request stage.

    Extends the collector's base labels with a ``stage`` label and
    observes the sample on the per-stage latency histogram.
    """
    full_labels = dict(self.labels)
    full_labels["stage"] = stage
    histogram = self.request_latency_seconds.labels(**full_labels)
    histogram.observe(latency)

def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
Expand Down
9 changes: 2 additions & 7 deletions python/sglang/srt/metrics/func_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from functools import wraps
from typing import Any, Callable, List, Optional

from sglang.srt.metrics.utils import exponential_buckets

enable_metrics = False


Expand All @@ -42,13 +44,6 @@ def enable_func_timer():
FUNC_LATENCY = None


def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return a geometric series of histogram bucket bounds.

    The i-th bound is ``start * width**i`` for i in [0, length).
    """
    return [start * (width**i) for i in range(length)]


def time_func_latency(
func: Callable = None, name: Optional[str] = None
) -> Callable[..., Any]:
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,10 @@ def generate_buckets(
return sorted(set(default_buckets))
assert rule == "customer"
return sorted(set([float(x) for x in buckets_rule[1:]]))


def exponential_buckets(start: float, width: float, length: int) -> List[float]:
    """Return ``length`` histogram bucket bounds growing geometrically.

    The i-th bound is ``start * width**i``, so the bounds span
    ``[start, start * width**(length - 1)]``. Returns an empty list when
    ``length <= 0``.

    Replaces the manual append loop with an equivalent list comprehension
    (identical arithmetic: each element is still ``start * width**i``).
    """
    return [start * (width**i) for i in range(length)]
Loading