Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8f2208a
add the KVCache transfer latency metric
SCDESPERTATE Jul 10, 2025
7b2a6c8
Merge branch 'sgl-project:main' into add_send_kvcache_lat_metric
SCDESPERTATE Jul 11, 2025
bd1b6aa
format the changes
SCDESPERTATE Jul 11, 2025
291399c
pass `metrics_collector` on demand
SCDESPERTATE Jul 11, 2025
6fd9f1a
Merge branch 'main' into add_send_kvcache_lat_metric
SCDESPERTATE Jul 11, 2025
d7ff04d
avoid circular import
SCDESPERTATE Jul 11, 2025
342cc6c
Merge branch 'main' into add_send_kvcache_lat_metric
SCDESPERTATE Jul 25, 2025
763c2b4
update all `KVManager.__init__` interfaces to accommodate the collector
SCDESPERTATE Jul 25, 2025
9118c71
update `BaseKVManager.__init__` interface to receiver collector
SCDESPERTATE Jul 25, 2025
76039d3
tidy the timestamp_record process inside the `KVManager.transfer_worker`
SCDESPERTATE Jul 28, 2025
e1a67a0
fix a typo
SCDESPERTATE Jul 28, 2025
face952
Merge branch 'main' into add_send_kvcache_lat_metric
stmatengss Aug 4, 2025
f4e3742
follow the change in PR#8483
SCDESPERTATE Aug 4, 2025
325d56a
a follow-up of the last commit
SCDESPERTATE Aug 4, 2025
3c01861
Merge branch 'main' into add_send_kvcache_lat_metric
stmatengss Aug 4, 2025
1d8e524
abstract kvcache_latency metric and its operations into a class
SCDESPERTATE Aug 4, 2025
8037779
fix a typo in utils.py
SCDESPERTATE Aug 5, 2025
7d77b50
Merge branch 'main' into add_send_kvcache_lat_metric
SCDESPERTATE Aug 5, 2025
6ee4e17
Merge branch 'main' into add_send_kvcache_lat_metric
stmatengss Aug 5, 2025
d2111b1
fix lint error
SCDESPERTATE Aug 5, 2025
2466bf7
Merge branch 'main' into add_send_kvcache_lat_metric
stmatengss Aug 5, 2025
e450144
Merge branch 'main' into add_send_kvcache_lat_metric
stmatengss Aug 7, 2025
1db01fa
Merge branch 'main' into add_send_kvcache_lat_metric
SCDESPERTATE Aug 12, 2025
7988552
Merge branch 'main' into add_send_kvcache_lat_metric
SCDESPERTATE Aug 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import numpy.typing as npt

from sglang.srt.metrics.collector import SchedulerMetricsCollector
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as the other pr. don't pass collector into kv manager

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a deep analysis, I find it difficult to achieve a fine-grained latency tracking without passing the collector to the KVManager && the coordination inside the conn.py🤔 The purpose of this PR is to track the KV transfer latency of the exact network stack, reflecting the real-time network performance. However, if the timestamp collecting is only allowed in the prefill.py, though generality is preserved, other irrelevant latencies like request queueing, scheduler dispatching and result polling would be included in this metric, which would mislead the operators. Hence, I think passing the collector into the KVManager in this case is quite necessary.

): ...


Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.metrics.collector import SchedulerMetricsCollector
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
Expand All @@ -42,6 +43,7 @@ def __init__(
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
self.kv_args = args
self.is_mla_backend = is_mla_backend
Expand Down
27 changes: 26 additions & 1 deletion python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@
group_concurrent_contiguous,
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVCacheTransferLatencyMonitor,
)
from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.metrics.collector import SchedulerMetricsCollector
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
Expand Down Expand Up @@ -144,6 +148,7 @@ def __init__(
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
self.kv_args = args
self.local_ip = get_local_ip_auto()
Expand Down Expand Up @@ -171,6 +176,7 @@ def __init__(
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)

self.scheduler_metrics_collector = scheduler_metrics_collector
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
Expand Down Expand Up @@ -211,9 +217,14 @@ def __init__(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
)

self.kvcache_transfer_latency_monitor = KVCacheTransferLatencyMonitor(
self.scheduler_metrics_collector
)

self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)

elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.heartbeat_failures = {}
self.session_pool = defaultdict(requests.Session)
Expand Down Expand Up @@ -649,6 +660,11 @@ def transfer_worker(
target_rank_registration_info: KVArgsRegisterInfo = (
self.decode_kv_args_table[req.mooncake_session_id]
)

self.kvcache_transfer_latency_monitor.collect_begin_timestamp(
req.room, req.mooncake_session_id
)

if self.is_mla_backend or (
self.attn_tp_size
== target_rank_registration_info.dst_attn_tp_size
Expand All @@ -671,6 +687,11 @@ def transfer_worker(
target_rank_registration_info.dst_kv_item_len,
executor,
)

self.kvcache_transfer_latency_monitor.collect_finish_timestamp(
req.room, req.mooncake_session_id, ret
)

if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
Expand All @@ -695,6 +716,9 @@ def transfer_worker(
break

if kv_chunk.is_last:
self.kvcache_transfer_latency_monitor.record(
req.room, req.mooncake_session_id
)
# Only the last chunk we need to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
Expand Down Expand Up @@ -727,6 +751,7 @@ def transfer_worker(
):
if kv_chunk.room in self.transfer_infos:
self.transfer_infos.pop(kv_chunk.room)
self.kvcache_transfer_latency_monitor.pop_room(kv_chunk.room)

except Exception as e:
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems you don't implement the same latency monitor for nixl

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. So far, I'm not quite familiar with the NIXL part code, may be later I would go through that then add support for it. But for interface compatibility, the collector field has to be added, otherwise exception would be raised.

from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.metrics.collector import SchedulerMetricsCollector
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
scheduler_metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
try:
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
get_metrics_collector,
is_mla_backend,
kv_to_page_indices,
kv_to_page_num,
Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(
self.max_total_num_tokens = max_total_num_tokens
self.scheduler = scheduler
self.transfer_backend = transfer_backend
self.metrics_collector = get_metrics_collector(self.scheduler)
self.kv_manager = self._init_kv_manager()

def _init_kv_manager(self) -> BaseKVManager:
Expand Down Expand Up @@ -140,6 +142,7 @@ def _init_kv_manager(self) -> BaseKVManager:
DisaggregationMode.PREFILL,
self.scheduler.server_args,
self.is_mla_backend,
self.metrics_collector,
)
return kv_manager

Expand Down
61 changes: 60 additions & 1 deletion python/sglang/srt/disaggregation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import os
import random
import threading
import time
import warnings
from collections import deque
from contextlib import nullcontext
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy as np
import requests
Expand All @@ -19,6 +20,8 @@

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.managers.scheduler_metrics_mixin import SchedulerMetricsCollector

#########################
# Constants & Enums
Expand Down Expand Up @@ -347,6 +350,62 @@ def register_disaggregation_server(
)


#########################
# Monitor Metrics
#########################


def get_metrics_collector(scheduler: Scheduler):
if scheduler.enable_metrics:
return scheduler.metrics_collector
return None


class KVCacheTransferLatencyMonitor:
def __init__(
self,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
self.metrics_collector = metrics_collector
self.kvcache_transfer_latency_table: Dict[int, Dict[str, float]] = {}

def collect_begin_timestamp(self, room: int, dst_id: str):
if self.metrics_collector is None:
return

if room not in self.kvcache_transfer_latency_table:
self.kvcache_transfer_latency_table[room] = {dst_id: -time.time()}
else:
if dst_id not in self.kvcache_transfer_latency_table[room]:
self.kvcache_transfer_latency_table[room][dst_id] = 0

self.kvcache_transfer_latency_table[room][dst_id] -= time.time()

def collect_finish_timestamp(self, room: int, dst_id: str, ret: int):
if self.metrics_collector is None:
return

assert room in self.kvcache_transfer_latency_table
assert dst_id in self.kvcache_transfer_latency_table[room]
if ret != 0:
self.kvcache_transfer_latency_table[room][dst_id] = 0
else:
self.kvcache_transfer_latency_table[room][dst_id] += time.time()

def record(self, room: int, dst_id: str):
if self.metrics_collector is not None:
self.metrics_collector.observe_kvcache_transfer_latency(
self.kvcache_transfer_latency_table[room].pop(dst_id)
)

def pop_room(self, room: int):
if (
self.metrics_collector is not None
and room in self.kvcache_transfer_latency_table
):
self.kvcache_transfer_latency_table.pop(room)


#########################
# Misc
#########################
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/managers/scheduler_metrics_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
}
if dp_rank is not None:
labels["dp_rank"] = dp_rank
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
self.metrics_collector = SchedulerMetricsCollector(
labels=labels,
bucket_kvcache_transfer_latency=self.server_args.bucket_kvcache_transfer_latency,
)

def init_kv_events(self, kv_events_config: Optional[str]):
if self.enable_kv_cache_events:
Expand Down
49 changes: 47 additions & 2 deletions python/sglang/srt/metrics/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,18 @@ class SchedulerStats:
num_decode_prealloc_queue_reqs: int = 0
num_decode_transfer_queue_reqs: int = 0
total_retracted_reqs: int = 0
kvcache_transfer_latency: float = 0.0


class SchedulerMetricsCollector:

def __init__(self, labels: Dict[str, str]) -> None:
def __init__(
self,
labels: Dict[str, str],
bucket_kvcache_transfer_latency: Optional[List[float]] = None,
) -> None:
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
from prometheus_client import Counter, Gauge
from prometheus_client import Counter, Gauge, Histogram

self.labels = labels
self.last_log_time = time.perf_counter()
Expand Down Expand Up @@ -268,10 +273,50 @@ def __init__(self, labels: Dict[str, str]) -> None:
labelnames=labels.keys(),
)

if bucket_kvcache_transfer_latency is None:
bucket_kvcache_transfer_latency = [
0.001,
0.002,
0.004,
0.006,
0.008,
0.01,
0.02,
0.04,
0.06,
0.08,
0.1,
0.2,
0.4,
0.6,
0.8,
1,
2,
4,
6,
8,
10,
20,
40,
60,
80,
100,
]

self.histogram_kvcache_transfer_latency = Histogram(
name="sglang:kvcache_transfer_latency",
documentation="Histogram of kvcache transfer latency in seconds.",
labelnames=labels.keys(),
buckets=bucket_kvcache_transfer_latency,
)

def _log_gauge(self, gauge, data: Union[int, float]) -> None:
# Convenience function for logging to gauge.
gauge.labels(**self.labels).set(data)

def observe_kvcache_transfer_latency(self, value: float) -> None:
self.histogram_kvcache_transfer_latency.labels(**self.labels).observe(value)

def increment_bootstrap_failed_reqs(self) -> None:
self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)

Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ServerArgs:
bucket_time_to_first_token: Optional[List[float]] = None
bucket_inter_token_latency: Optional[List[float]] = None
bucket_e2e_request_latency: Optional[List[float]] = None
bucket_kvcache_transfer_latency: Optional[List[float]] = None
collect_tokens_histogram: bool = False
decode_log_interval: int = 40
enable_request_time_stats_logging: bool = False
Expand Down Expand Up @@ -1156,6 +1157,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.bucket_e2e_request_latency,
help="The buckets of end-to-end request latency, specified as a list of floats.",
)
parser.add_argument(
"--bucket-kvcache-transfer-latency",
type=float,
nargs="+",
default=ServerArgs.bucket_kvcache_transfer_latency,
help="The buckets of request kvcache transfer latency, specified as a list of floats.",
)
parser.add_argument(
"--collect-tokens-histogram",
action="store_true",
Expand Down
Loading