Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

if TYPE_CHECKING:
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.observability.metrics_collector import SchedulerStats

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,6 +105,17 @@ def get_decode_usage_msg_parts(self) -> List[str]:
)
return parts

def update_scheduler_stats(self, stats: SchedulerStats) -> None:
    """Copy this pool's usage figures onto the given SchedulerStats.

    Populates ``num_used_tokens``, ``token_usage`` (max usage over all
    pools, rounded to 2 decimals) and ``full_token_usage``; the SWA and
    Mamba fields are written only when the corresponding hybrid pool
    exists on this object.
    """
    # Only the used-token count matters here; drop the second element.
    stats.num_used_tokens, _ = self.get_kv_token_stats()
    stats.token_usage = round(self.get_max_pool_usage(), 2)
    stats.full_token_usage = self.full_token_usage
    # Hybrid pools carry extra per-pool usage metrics; mirror them when present.
    for enabled, field in (
        (self.is_hybrid_swa, "swa_token_usage"),
        (self.is_hybrid_ssm, "mamba_usage"),
    ):
        if enabled:
            setattr(stats, field, getattr(self, field))


class SchedulerRuntimeCheckerMixin:
def _session_held_tokens(self: Scheduler) -> int:
Expand Down Expand Up @@ -405,15 +417,12 @@ def check_memory(self: Scheduler):
and time.perf_counter() > self.metrics_collector.last_log_time + 30
):
# During idle time, also collect metrics every 30 seconds.
pool_stats = self.get_pool_stats()
num_used, _ = pool_stats.get_kv_token_stats()
self.get_pool_stats().update_scheduler_stats(self.stats)

priority_enabled = self.enable_priority_scheduling
self.stats.num_running_reqs = QueueCount.from_reqs(
self.running_batch.reqs, priority_enabled
)
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(pool_stats.get_max_pool_usage(), 2)
self.stats.gen_throughput = 0
self.stats.num_queue_reqs = QueueCount.from_reqs(
self.waiting_queue, priority_enabled
Expand Down
24 changes: 2 additions & 22 deletions python/sglang/srt/observability/scheduler_metrics_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,6 @@ def report_prefill_stats(
self.last_prefill_tokens = prefill_stats.log_input_tokens

pool_stats = self.get_pool_stats()
num_used, _ = pool_stats.get_kv_token_stats()
max_pool_usage = pool_stats.get_max_pool_usage()
full_token_usage = pool_stats.full_token_usage
token_usage_msg = ", ".join(pool_stats.get_prefill_usage_msg_parts()) + ", "

self.stats.new_token_ratio = prefill_stats.new_token_ratio
Expand Down Expand Up @@ -417,13 +414,7 @@ def report_prefill_stats(

self.stats.num_running_reqs = prefill_stats.num_running_reqs
self.stats.num_running_reqs_offline_batch = 0
self.stats.num_used_tokens = num_used
self.stats.token_usage = max_pool_usage
self.stats.full_token_usage = full_token_usage
if pool_stats.is_hybrid_swa:
self.stats.swa_token_usage = pool_stats.swa_token_usage
if pool_stats.is_hybrid_ssm:
self.stats.mamba_usage = pool_stats.mamba_usage
pool_stats.update_scheduler_stats(self.stats)

priority_enabled = self.enable_priority_scheduling
self.stats.num_queue_reqs = QueueCount.from_reqs(
Expand Down Expand Up @@ -515,9 +506,6 @@ def report_decode_stats(
num_running_reqs_offline_batch = 0

pool_stats = self.get_pool_stats()
num_used, _ = pool_stats.get_kv_token_stats()
max_pool_usage = pool_stats.get_max_pool_usage()
full_token_usage = pool_stats.full_token_usage
token_usage_msg = ", ".join(pool_stats.get_decode_usage_msg_parts()) + ", "

if RECORD_STEP_TIME:
Expand Down Expand Up @@ -603,15 +591,7 @@ def report_decode_stats(
batch.reqs, priority_enabled
)
self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
self.stats.num_used_tokens = num_used
# maximum usage of all pools
self.stats.token_usage = max_pool_usage
# usage of full attention
self.stats.full_token_usage = full_token_usage
if pool_stats.is_hybrid_swa:
self.stats.swa_token_usage = pool_stats.swa_token_usage
if pool_stats.is_hybrid_ssm:
self.stats.mamba_usage = pool_stats.mamba_usage
pool_stats.update_scheduler_stats(self.stats)
self.stats.decode_sum_seq_lens = batch.seq_lens_cpu.sum().item()
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = QueueCount.from_reqs(
Expand Down
Loading