diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index d37575dcf0aa..f336cac9fd3e 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -6,6 +6,7 @@ import numpy as np import numpy.typing as npt +from sglang.srt.metrics.collector import SchedulerMetricsCollector from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -52,6 +53,7 @@ def __init__( disaggregation_mode: DisaggregationMode, server_args: ServerArgs, is_mla_backend: Optional[bool] = False, + scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None, ): ... diff --git a/python/sglang/srt/disaggregation/common/conn.py b/python/sglang/srt/disaggregation/common/conn.py index da6cc7217849..846044cb4f7d 100644 --- a/python/sglang/srt/disaggregation/common/conn.py +++ b/python/sglang/srt/disaggregation/common/conn.py @@ -22,6 +22,7 @@ KVPoll, ) from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.metrics.collector import SchedulerMetricsCollector from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( format_tcp_address, @@ -42,6 +43,7 @@ def __init__( disaggregation_mode: DisaggregationMode, server_args: ServerArgs, is_mla_backend: Optional[bool] = False, + scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None, ): self.kv_args = args self.is_mla_backend = is_mla_backend diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 9e35078e7781..3df5ea24b575 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -33,13 +33,17 @@ group_concurrent_contiguous, ) from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine -from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.disaggregation.utils import ( + DisaggregationMode, + KVCacheTransferLatencyMonitor, +) from sglang.srt.layers.dp_attention import ( get_attention_dp_rank, get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, ) +from sglang.srt.metrics.collector import SchedulerMetricsCollector from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( format_tcp_address, @@ -144,6 +148,7 @@ def __init__( disaggregation_mode: DisaggregationMode, server_args: ServerArgs, is_mla_backend: Optional[bool] = False, + scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None, ): self.kv_args = args self.local_ip = get_local_ip_auto() @@ -171,6 +176,7 @@ def __init__( if is_valid_ipv6_address(self.local_ip): self.server_socket.setsockopt(zmq.IPV6, 1) + self.scheduler_metrics_collector = scheduler_metrics_collector self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} @@ -211,9 +217,14 @@ def __init__( "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300 ) + self.kvcache_transfer_latency_monitor = KVCacheTransferLatencyMonitor( + self.scheduler_metrics_collector + ) + self.enable_custom_mem_pool = get_bool_env_var( "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" ) + elif self.disaggregation_mode == DisaggregationMode.DECODE: self.heartbeat_failures = {} self.session_pool = defaultdict(requests.Session) @@ -649,6 +660,11 @@ def transfer_worker( target_rank_registration_info: KVArgsRegisterInfo = ( self.decode_kv_args_table[req.mooncake_session_id] ) + + self.kvcache_transfer_latency_monitor.collect_begin_timestamp( + req.room, req.mooncake_session_id + ) + if self.is_mla_backend or ( self.attn_tp_size == target_rank_registration_info.dst_attn_tp_size @@ -671,6 +687,11 @@ def transfer_worker( target_rank_registration_info.dst_kv_item_len, executor, ) + + self.kvcache_transfer_latency_monitor.collect_finish_timestamp( + req.room, req.mooncake_session_id, ret + ) + if ret != 0: with self.session_lock: self.session_failures[req.mooncake_session_id] += 1 @@ -695,6 +716,9 @@ def transfer_worker( break if kv_chunk.is_last: + self.kvcache_transfer_latency_monitor.record( + req.room, req.mooncake_session_id + ) # Only the last chunk we need to send the aux data ret = self.send_aux( req.mooncake_session_id, @@ -727,6 +751,7 @@ def transfer_worker( ): if kv_chunk.room in self.transfer_infos: self.transfer_infos.pop(kv_chunk.room) + self.kvcache_transfer_latency_monitor.pop_room(kv_chunk.room) except Exception as e: # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 7a75d79b740d..a25523ee636d 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -26,6 +26,7 @@ ) from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.metrics.collector import SchedulerMetricsCollector from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( format_tcp_address, @@ -117,6 +118,7 @@ def __init__( disaggregation_mode: DisaggregationMode, server_args: ServerArgs, is_mla_backend: Optional[bool] = False, + scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None, ): super().__init__(args, disaggregation_mode, server_args, is_mla_backend) try: diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 72cf9d3f953e..5c987dceba49 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -36,6 +36,7 @@ ReqToMetadataIdxAllocator, TransferBackend, get_kv_class, + get_metrics_collector, is_mla_backend, kv_to_page_indices, kv_to_page_num, @@ -97,6 +98,7 @@ def __init__( self.max_total_num_tokens = max_total_num_tokens self.scheduler = scheduler self.transfer_backend = transfer_backend + self.metrics_collector = get_metrics_collector(self.scheduler) self.kv_manager = self._init_kv_manager() def _init_kv_manager(self) -> BaseKVManager: @@ -140,6 +142,7 @@ def _init_kv_manager(self) -> BaseKVManager: DisaggregationMode.PREFILL, self.scheduler.server_args, self.is_mla_backend, + self.metrics_collector, ) return kv_manager diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 720c9d5a59e9..10c2c2558790 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -4,11 +4,12 @@ import os import random import threading +import time import warnings from collections import deque from contextlib import nullcontext from enum import Enum -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional import numpy as np import requests @@ -19,6 +20,8 @@ if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req + from sglang.srt.managers.scheduler import Scheduler + from sglang.srt.managers.scheduler_metrics_mixin import SchedulerMetricsCollector ######################### # Constants & Enums @@ -347,6 +350,62 @@ def register_disaggregation_server( ) +######################### +# Monitor Metrics +######################### + + +def get_metrics_collector(scheduler: Scheduler): + if scheduler.enable_metrics: + return scheduler.metrics_collector + return None + + +class KVCacheTransferLatencyMonitor: + def __init__( + self, + metrics_collector: Optional[SchedulerMetricsCollector] = None, + ): + self.metrics_collector = metrics_collector + self.kvcache_transfer_latency_table: Dict[int, Dict[str, float]] = {} + + def collect_begin_timestamp(self, room: int, dst_id: str): + if self.metrics_collector is None: + return + + if room not in self.kvcache_transfer_latency_table: + self.kvcache_transfer_latency_table[room] = {dst_id: -time.time()} + else: + if dst_id not in self.kvcache_transfer_latency_table[room]: + self.kvcache_transfer_latency_table[room][dst_id] = 0 + + self.kvcache_transfer_latency_table[room][dst_id] -= time.time() + + def collect_finish_timestamp(self, room: int, dst_id: str, ret: int): + if self.metrics_collector is None: + return + + assert room in self.kvcache_transfer_latency_table + assert dst_id in self.kvcache_transfer_latency_table[room] + if ret != 0: + self.kvcache_transfer_latency_table[room][dst_id] = 0 + else: + self.kvcache_transfer_latency_table[room][dst_id] += time.time() + + def record(self, room: int, dst_id: str): + if self.metrics_collector is not None: + self.metrics_collector.observe_kvcache_transfer_latency( + self.kvcache_transfer_latency_table[room].pop(dst_id) + ) + + def pop_room(self, room: int): + if ( + self.metrics_collector is not None + and room in self.kvcache_transfer_latency_table + ): + self.kvcache_transfer_latency_table.pop(room) + + ######################### # Misc ######################### diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index a6497ffde5c1..777eb5440b65 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -48,7 +48,10 @@ def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]): } if dp_rank is not None: labels["dp_rank"] = dp_rank - self.metrics_collector = SchedulerMetricsCollector(labels=labels) + self.metrics_collector = SchedulerMetricsCollector( + labels=labels, + bucket_kvcache_transfer_latency=self.server_args.bucket_kvcache_transfer_latency, + ) def init_kv_events(self, kv_events_config: Optional[str]): if self.enable_kv_cache_events: diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 4c32b8fc6348..19b5e6fe93bf 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -146,13 +146,18 @@ class SchedulerStats: num_decode_prealloc_queue_reqs: int = 0 num_decode_transfer_queue_reqs: int = 0 total_retracted_reqs: int = 0 + kvcache_transfer_latency: float = 0.0 class SchedulerMetricsCollector: - def __init__(self, labels: Dict[str, str]) -> None: + def __init__( + self, + labels: Dict[str, str], + bucket_kvcache_transfer_latency: Optional[List[float]] = None, + ) -> None: # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` - from prometheus_client import Counter, Gauge + from prometheus_client import Counter, Gauge, Histogram self.labels = labels self.last_log_time = time.perf_counter() @@ -268,10 +273,50 @@ def __init__(self, labels: Dict[str, str]) -> None: labelnames=labels.keys(), ) + if bucket_kvcache_transfer_latency is None: + bucket_kvcache_transfer_latency = [ + 0.001, + 0.002, + 0.004, + 0.006, + 0.008, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.2, + 0.4, + 0.6, + 0.8, + 1, + 2, + 4, + 6, + 8, + 10, + 20, + 40, + 60, + 80, + 100, + ] + + self.histogram_kvcache_transfer_latency = Histogram( + name="sglang:kvcache_transfer_latency", + documentation="Histogram of kvcache transfer latency in seconds.", + labelnames=labels.keys(), + buckets=bucket_kvcache_transfer_latency, + ) + def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. gauge.labels(**self.labels).set(data) + def observe_kvcache_transfer_latency(self, value: float) -> None: + self.histogram_kvcache_transfer_latency.labels(**self.labels).observe(value) + def increment_bootstrap_failed_reqs(self) -> None: self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index fd2bd1580b32..6ad2d0f34703 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -117,6 +117,7 @@ class ServerArgs: bucket_time_to_first_token: Optional[List[float]] = None bucket_inter_token_latency: Optional[List[float]] = None bucket_e2e_request_latency: Optional[List[float]] = None + bucket_kvcache_transfer_latency: Optional[List[float]] = None collect_tokens_histogram: bool = False decode_log_interval: int = 40 enable_request_time_stats_logging: bool = False @@ -1156,6 +1157,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.bucket_e2e_request_latency, help="The buckets of end-to-end request latency, specified as a list of floats.", ) + parser.add_argument( + "--bucket-kvcache-transfer-latency", + type=float, + nargs="+", + default=ServerArgs.bucket_kvcache_transfer_latency, + help="The buckets of request kvcache transfer latency, specified as a list of floats.", + ) parser.add_argument( "--collect-tokens-histogram", action="store_true",