diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
index c1c1ab48b538..8fdd4391e10e 100644
--- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
+++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
@@ -17,6 +17,7 @@
 
 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import Scheduler
+    from sglang.srt.observability.metrics_collector import SchedulerStats
 
 logger = logging.getLogger(__name__)
 
@@ -104,6 +105,17 @@ def get_decode_usage_msg_parts(self) -> List[str]:
         )
         return parts
 
+    def update_scheduler_stats(self, stats: SchedulerStats) -> None:
+        """Update pool-related fields on SchedulerStats."""
+        num_used, _ = self.get_kv_token_stats()
+        stats.num_used_tokens = num_used
+        stats.token_usage = round(self.get_max_pool_usage(), 2)
+        stats.full_token_usage = self.full_token_usage
+        if self.is_hybrid_swa:
+            stats.swa_token_usage = self.swa_token_usage
+        if self.is_hybrid_ssm:
+            stats.mamba_usage = self.mamba_usage
+
 
 class SchedulerRuntimeCheckerMixin:
     def _session_held_tokens(self: Scheduler) -> int:
@@ -405,15 +417,12 @@ def check_memory(self: Scheduler):
             and time.perf_counter() > self.metrics_collector.last_log_time + 30
         ):
             # During idle time, also collect metrics every 30 seconds.
-            pool_stats = self.get_pool_stats()
-            num_used, _ = pool_stats.get_kv_token_stats()
+            self.get_pool_stats().update_scheduler_stats(self.stats)
             priority_enabled = self.enable_priority_scheduling
             self.stats.num_running_reqs = QueueCount.from_reqs(
                 self.running_batch.reqs, priority_enabled
             )
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(pool_stats.get_max_pool_usage(), 2)
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = QueueCount.from_reqs(
                 self.waiting_queue, priority_enabled
             )
diff --git a/python/sglang/srt/observability/scheduler_metrics_mixin.py b/python/sglang/srt/observability/scheduler_metrics_mixin.py
index dc28c57eaf09..18ea325bae6a 100644
--- a/python/sglang/srt/observability/scheduler_metrics_mixin.py
+++ b/python/sglang/srt/observability/scheduler_metrics_mixin.py
@@ -343,9 +343,6 @@ def report_prefill_stats(
         self.last_prefill_tokens = prefill_stats.log_input_tokens
 
         pool_stats = self.get_pool_stats()
-        num_used, _ = pool_stats.get_kv_token_stats()
-        max_pool_usage = pool_stats.get_max_pool_usage()
-        full_token_usage = pool_stats.full_token_usage
         token_usage_msg = ", ".join(pool_stats.get_prefill_usage_msg_parts()) + ", "
 
         self.stats.new_token_ratio = prefill_stats.new_token_ratio
@@ -417,13 +414,7 @@ def report_prefill_stats(
 
         self.stats.num_running_reqs = prefill_stats.num_running_reqs
         self.stats.num_running_reqs_offline_batch = 0
-        self.stats.num_used_tokens = num_used
-        self.stats.token_usage = max_pool_usage
-        self.stats.full_token_usage = full_token_usage
-        if pool_stats.is_hybrid_swa:
-            self.stats.swa_token_usage = pool_stats.swa_token_usage
-        if pool_stats.is_hybrid_ssm:
-            self.stats.mamba_usage = pool_stats.mamba_usage
+        pool_stats.update_scheduler_stats(self.stats)
 
         priority_enabled = self.enable_priority_scheduling
         self.stats.num_queue_reqs = QueueCount.from_reqs(
@@ -515,9 +506,6 @@ def report_decode_stats(
             num_running_reqs_offline_batch = 0
 
         pool_stats = self.get_pool_stats()
-        num_used, _ = pool_stats.get_kv_token_stats()
-        max_pool_usage = pool_stats.get_max_pool_usage()
-        full_token_usage = pool_stats.full_token_usage
         token_usage_msg = ", ".join(pool_stats.get_decode_usage_msg_parts()) + ", "
 
         if RECORD_STEP_TIME:
@@ -603,15 +591,7 @@ def report_decode_stats(
             batch.reqs, priority_enabled
         )
         self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
-        self.stats.num_used_tokens = num_used
-        # maximum usage of all pools
-        self.stats.token_usage = max_pool_usage
-        # usage of full attention
-        self.stats.full_token_usage = full_token_usage
-        if pool_stats.is_hybrid_swa:
-            self.stats.swa_token_usage = pool_stats.swa_token_usage
-        if pool_stats.is_hybrid_ssm:
-            self.stats.mamba_usage = pool_stats.mamba_usage
+        pool_stats.update_scheduler_stats(self.stats)
         self.stats.decode_sum_seq_lens = batch.seq_lens_cpu.sum().item()
         self.stats.gen_throughput = self.last_gen_throughput
         self.stats.num_queue_reqs = QueueCount.from_reqs(