diff --git a/python/sglang/srt/dllm/mixin/scheduler.py b/python/sglang/srt/dllm/mixin/scheduler.py index 5b87eff5fe33..e5404b0e3205 100644 --- a/python/sglang/srt/dllm/mixin/scheduler.py +++ b/python/sglang/srt/dllm/mixin/scheduler.py @@ -27,7 +27,7 @@ def init_diffusion_llm(self: Scheduler): def get_new_batch_dllm(self: Scheduler) -> Optional[ScheduleBatch]: """Generate a new batch for DLLM (Diffusion LLM) scheduling.""" - if self.try_preemption: + if self.enable_priority_preemption: self.running_batch.batch_is_full = False # Early exit if batch is full or no requests available @@ -82,7 +82,7 @@ def _should_skip_prefill(self: Scheduler) -> bool: if ( self.get_num_allocatable_reqs(running_bs) <= 0 and self.dllm_manager.is_empty() - and not self.try_preemption + and not self.enable_priority_preemption ): self.running_batch.batch_is_full = True return True @@ -186,12 +186,8 @@ def _create_dllm_batch( # Record prefill stats for logging after forward from sglang.srt.observability.scheduler_metrics_mixin import PrefillStats - new_batch.prefill_stats = PrefillStats( - log_input_tokens=self.adder.log_input_tokens, - log_hit_tokens=self.adder.log_hit_tokens, - new_token_ratio=self.adder.new_token_ratio, - running_bs=len(self.running_batch.reqs), - num_new_seqs=len(can_run_list), + new_batch.prefill_stats = PrefillStats.from_adder( + self.adder, self.running_batch.reqs, self.enable_priority_scheduling ) return new_batch @@ -209,8 +205,9 @@ def process_dllm_incoming_reqs( # Try preemption if batch is full if self.running_batch.batch_is_full: - if not self.try_preemption or not adder.preempt_to_schedule( - req, self.server_args + if ( + not self.enable_priority_preemption + or not adder.preempt_to_schedule(req, self.server_args) ): break diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7c4806c4b38c..023efb9f32f0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -813,8 
+813,13 @@ def init_schedule_policy(self): else "cpu" ), ) - # Enable preemption for priority scheduling. - self.try_preemption = self.enable_priority_scheduling + + # NOTE: preemption is enabled by default for priority scheduling. + self.enable_priority_preemption = ( + self.enable_priority_scheduling + and not self.server_args.disable_priority_preemption + ) + self.init_new_token_ratio = min( envs.SGLANG_INIT_NEW_TOKEN_RATIO.get() * self.server_args.schedule_conservativeness, @@ -1994,7 +1999,7 @@ def _get_new_batch_prefill_raw( for req in ready_grammar_requests: self._add_request_to_queue(req) - if self.try_preemption: + if self.enable_priority_preemption: # Reset batch_is_full to try preemption with a prefill adder. self.running_batch.batch_is_full = False @@ -2004,6 +2009,7 @@ def _get_new_batch_prefill_raw( return None running_bs = len(self.running_batch.reqs) + # Ignore the check if self.chunked_req is not None. # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0, # as the space for the chunked requests has just been released. 
@@ -2012,7 +2018,7 @@ def _get_new_batch_prefill_raw( if ( self.get_num_allocatable_reqs(running_bs) <= 0 and self.chunked_req is not None - and not self.try_preemption + and not self.enable_priority_preemption ): self.running_batch.batch_is_full = True return None @@ -2090,8 +2096,9 @@ def _get_new_batch_prefill_raw( self.running_batch.batch_is_full = True if self.running_batch.batch_is_full: - if not self.try_preemption or not adder.preempt_to_schedule( - req, self.server_args + if ( + not self.enable_priority_preemption + or not adder.preempt_to_schedule(req, self.server_args) ): break @@ -2174,12 +2181,8 @@ def _get_new_batch_prefill_raw( new_batch.prepare_for_extend() # Record prefill stats for logging after forward - new_batch.prefill_stats = PrefillStats( - log_input_tokens=adder.log_input_tokens, - log_hit_tokens=adder.log_hit_tokens, - new_token_ratio=adder.new_token_ratio, - running_bs=len(self.running_batch.reqs), - num_new_seqs=len(can_run_list), + new_batch.prefill_stats = PrefillStats.from_adder( + adder, self.running_batch.reqs, self.enable_priority_scheduling ) # Mixed-style chunked prefill diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 8ec82c49a4f0..a2fb1f20a1aa 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -9,6 +9,7 @@ from sglang.srt.environ import envs from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.session_aware_cache import SessionAwareCache +from sglang.srt.observability.metrics_collector import QueueCount from sglang.srt.utils.common import ceil_align, raise_error_or_warn from sglang.srt.utils.request_logger import disable_request_logging from sglang.srt.utils.watchdog import WatchdogRaw @@ -301,26 +302,31 @@ def check_memory(self: Scheduler): ) = self._get_mamba_token_info() else: num_used, token_usage, _, _ 
= self._get_token_info() - num_running_reqs = len(self.running_batch.reqs) - self.stats.num_running_reqs = num_running_reqs + + _enable_ps = self.enable_priority_scheduling + self.stats.num_running_reqs = QueueCount.from_reqs( + self.running_batch.reqs, _enable_ps + ) self.stats.num_used_tokens = num_used self.stats.token_usage = round(token_usage, 2) self.stats.gen_throughput = 0 - self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_queue_reqs = QueueCount.from_reqs( + self.waiting_queue, _enable_ps + ) self.stats.num_grammar_queue_reqs = len(self.grammar_manager) if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.stats.num_prefill_prealloc_queue_reqs = len( - self.disagg_prefill_bootstrap_queue.queue + self.stats.num_prefill_prealloc_queue_reqs = QueueCount.from_reqs( + self.disagg_prefill_bootstrap_queue.queue, _enable_ps ) - self.stats.num_prefill_inflight_queue_reqs = len( - self.disagg_prefill_inflight_queue + self.stats.num_prefill_inflight_queue_reqs = QueueCount.from_reqs( + self.disagg_prefill_inflight_queue, _enable_ps ) if self.disaggregation_mode == DisaggregationMode.DECODE: - self.stats.num_decode_prealloc_queue_reqs = len( - self.disagg_decode_prealloc_queue.queue + self.stats.num_decode_prealloc_queue_reqs = QueueCount.from_reqs( + self.disagg_decode_prealloc_queue.queue, _enable_ps ) - self.stats.num_decode_transfer_queue_reqs = len( - self.disagg_decode_transfer_queue.queue + self.stats.num_decode_transfer_queue_reqs = QueueCount.from_reqs( + self.disagg_decode_transfer_queue.queue, _enable_ps ) self.metrics_collector.log_stats(self.stats) self._publish_kv_events() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bf9badbae9c4..05a8290026d9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -233,6 +233,8 @@ def init_model_config(self): self.context_len = self.model_config.context_len 
self.image_token_id = self.model_config.image_token_id self.max_req_input_len = None # Will be set later in engine.py + self.enable_priority_scheduling = server_args.enable_priority_scheduling + self.default_priority_value = server_args.default_priority_value speculative_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) @@ -421,6 +423,8 @@ def init_metric_collector_watchdog(self): "model_name": self.server_args.served_model_name, # TODO: Add lora name/path in the future, } + if self.enable_priority_scheduling: + labels["priority"] = "" if self.server_args.tokenizer_metrics_allowed_custom_labels: for label in self.server_args.tokenizer_metrics_allowed_custom_labels: labels[label] = "" @@ -482,6 +486,7 @@ async def generate_request( # Normalize the request obj.normalize_batch_and_arguments() + self._set_default_priority(obj) self._validate_rid(obj) if isinstance(obj, GenerateReqInput) and obj.routed_dp_rank is not None: @@ -1894,11 +1899,13 @@ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int): ) custom_labels = getattr(state.obj, "custom_labels", None) - labels = ( - {**self.metrics_collector.labels, **custom_labels} - if custom_labels - else self.metrics_collector.labels - ) + labels = dict(self.metrics_collector.labels) + if custom_labels: + labels.update(custom_labels) + if self.enable_priority_scheduling: + priority = getattr(state.obj, "priority", None) + if priority is not None: + labels["priority"] = str(priority) if ( state.time_stats.first_token_time == 0.0 and self.disaggregation_mode != DisaggregationMode.PREFILL @@ -2367,6 +2374,15 @@ def convert_to_span_attrs( return span_attrs + def _set_default_priority(self, obj: Union[GenerateReqInput, EmbeddingReqInput]): + """Set the default priority value.""" + if ( + self.enable_priority_scheduling + and obj.priority is None + and self.default_priority_value is not None + ): + obj.priority = self.default_priority_value + class ServerStatus(Enum): Up = 
"Up" diff --git a/python/sglang/srt/observability/metrics_collector.py b/python/sglang/srt/observability/metrics_collector.py index 07c24f950695..4180e1a892f3 100644 --- a/python/sglang/srt/observability/metrics_collector.py +++ b/python/sglang/srt/observability/metrics_collector.py @@ -13,12 +13,15 @@ # ============================================================================== """Utilities for Prometheus Metrics Collection.""" +from __future__ import annotations + import dataclasses import logging import os import time +from collections import Counter from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union from sglang.srt.environ import envs from sglang.srt.model_executor.forward_batch_info import ForwardMode @@ -27,8 +30,12 @@ from sglang.srt.utils import get_bool_env_var from sglang.srt.utils.gauge_histogram import GaugeHistogram -SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS") +if TYPE_CHECKING: + from prometheus_client import Gauge + from sglang.srt.managers.schedule_batch import Req + +SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS") logger = logging.getLogger(__name__) @@ -47,10 +54,30 @@ def get_histogram_conf_from_env(env_var_name: str) -> Optional[List[float]]: return [float(x) for x in env_var_value.split(",")] +@dataclass +class QueueCount: + """Holds both the total count and optional per-priority breakdown for a queue.""" + + total: int = 0 + by_priority: Optional[Dict[int, int]] = None + + @classmethod + def from_reqs(cls, reqs: List[Req], enable_priority_scheduling: bool = False): + # NOTE: If requests have priority=None (no --default-priority-value set), + # Counter will produce {None: N}, resulting in priority="None" Prometheus labels. + # Set --default-priority-value when enabling priority scheduling to avoid this. 
+ by_priority = ( + dict(Counter(req.priority for req in reqs)) + if enable_priority_scheduling + else None + ) + return cls(total=len(reqs), by_priority=by_priority) + + @dataclass class SchedulerStats: # Basics - num_running_reqs: int = 0 + num_running_reqs: QueueCount = field(default_factory=QueueCount) num_used_tokens: int = 0 token_usage: float = 0.0 pending_prealloc_token_usage: float = 0.0 @@ -58,7 +85,7 @@ class SchedulerStats: mamba_usage: float = 0.0 decode_sum_seq_lens: int = 0 gen_throughput: float = 0.0 - num_queue_reqs: int = 0 + num_queue_reqs: QueueCount = field(default_factory=QueueCount) num_grammar_queue_reqs: int = 0 num_running_reqs_offline_batch: int = 0 cache_hit_rate: float = 0.0 @@ -74,10 +101,10 @@ class SchedulerStats: num_paused_reqs: int = 0 # PD disaggregation - num_prefill_prealloc_queue_reqs: int = 0 - num_prefill_inflight_queue_reqs: int = 0 - num_decode_prealloc_queue_reqs: int = 0 - num_decode_transfer_queue_reqs: int = 0 + num_prefill_prealloc_queue_reqs: QueueCount = field(default_factory=QueueCount) + num_prefill_inflight_queue_reqs: QueueCount = field(default_factory=QueueCount) + num_decode_prealloc_queue_reqs: QueueCount = field(default_factory=QueueCount) + num_decode_transfer_queue_reqs: QueueCount = field(default_factory=QueueCount) kv_transfer_speed_gb_s: float = 0.0 kv_transfer_latency_ms: float = 0.0 kv_transfer_bootstrap_ms: float = 0.0 @@ -158,6 +185,7 @@ def __init__( self.enable_lora = enable_lora self.enable_hierarchical_cache = enable_hierarchical_cache self.last_log_time = time.perf_counter() + self._known_priorities: Set[int] = set() self.num_running_reqs = Gauge( name="sglang:num_running_reqs", @@ -733,9 +761,22 @@ def __init__( multiprocess_mode="mostrecent", ) - def _log_gauge(self, gauge, data: Union[int, float]) -> None: + def _log_gauge(self, gauge: Gauge, data: Union[int, float, QueueCount]) -> None: # Convenience function for logging to gauge. 
- gauge.labels(**self.labels).set(data) + if isinstance(data, QueueCount): + # NOTE: When priority scheduling is enabled, the total is recorded under + # priority="" (the default label value). Per-priority breakdowns are recorded + # with priority="<value>". Grafana queries should use priority="" for totals. + gauge.labels(**self.labels).set(data.total) + if data.by_priority is not None: + self._known_priorities.update(data.by_priority.keys()) + for priority in self._known_priorities: + value = data.by_priority.get(priority, 0) + labels = dict(self.labels) + labels["priority"] = str(priority) + gauge.labels(**labels).set(value) + else: + gauge.labels(**self.labels).set(data) def _log_histogram(self, histogram, data: Union[int, float]) -> None: histogram.labels(**self.labels).observe(data) diff --git a/python/sglang/srt/observability/scheduler_metrics_mixin.py b/python/sglang/srt/observability/scheduler_metrics_mixin.py index 487482498ed1..b5bea67b9961 100644 --- a/python/sglang/srt/observability/scheduler_metrics_mixin.py +++ b/python/sglang/srt/observability/scheduler_metrics_mixin.py @@ -5,7 +5,7 @@ import time from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch from sglang.srt.disaggregation.utils import DisaggregationMode @@ -25,6 +25,7 @@ from sglang.srt.managers.utils import GenerationBatchResult from sglang.srt.observability.metrics_collector import ( DPCooperationInfo, + QueueCount, SchedulerMetricsCollector, SchedulerStats, compute_routing_key_stats, @@ -34,6 +35,8 @@ from sglang.srt.utils.scheduler_status_logger import SchedulerStatusLogger if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + from sglang.srt.managers.schedule_policy import PrefillAdder from sglang.srt.managers.scheduler import EmbeddingBatchResult, Scheduler logger
= logging.getLogger(__name__) @@ -50,9 +53,26 @@ class PrefillStats: log_input_tokens: int log_hit_tokens: int new_token_ratio: float - running_bs: int + num_running_reqs: QueueCount num_new_seqs: int # len(can_run_list) + @classmethod + def from_adder( + cls, + adder: PrefillAdder, + running_reqs: List[Req], + enable_priority_scheduling: bool = False, + ): + return cls( + log_input_tokens=adder.log_input_tokens, + log_hit_tokens=adder.log_hit_tokens, + new_token_ratio=adder.new_token_ratio, + num_running_reqs=QueueCount.from_reqs( + running_reqs, enable_priority_scheduling + ), + num_new_seqs=len(adder.can_run_list), + ) + class KvMetrics: def __init__(self): @@ -118,6 +138,8 @@ def init_metrics( "pp_rank": pp_rank, "moe_ep_rank": self.moe_ep_rank, } + if self.enable_priority_scheduling: + labels["priority"] = "" if dp_rank is not None: labels["dp_rank"] = dp_rank if self.server_args.extra_metric_labels: @@ -217,7 +239,7 @@ def log_prefill_stats( f"#new-token: {prefill_stats.log_input_tokens}, " f"#cached-token: {prefill_stats.log_hit_tokens}, " f"{token_usage_msg}" - f"#running-req: {prefill_stats.running_bs}, " + f"#running-req: {prefill_stats.num_running_reqs.total}, " f"#queue-req: {len(self.waiting_queue)}, " ) @@ -255,7 +277,7 @@ def log_prefill_stats( prefill_stats.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0 ) - self.stats.num_running_reqs = prefill_stats.running_bs + self.stats.num_running_reqs = prefill_stats.num_running_reqs self.stats.num_running_reqs_offline_batch = 0 self.stats.num_used_tokens = num_used self.stats.token_usage = token_usage @@ -263,7 +285,11 @@ def log_prefill_stats( self.stats.swa_token_usage = swa_token_usage if self.is_hybrid_ssm: self.stats.mamba_usage = mamba_usage - self.stats.num_queue_reqs = len(self.waiting_queue) + + _enable_ps = self.enable_priority_scheduling + self.stats.num_queue_reqs = QueueCount.from_reqs( + self.waiting_queue, _enable_ps + ) self.stats.num_grammar_queue_reqs = len(self.grammar_manager) 
self.stats.cache_hit_rate = cache_hit_rate @@ -276,11 +302,11 @@ def log_prefill_stats( # PD disaggregation if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.stats.num_prefill_prealloc_queue_reqs = len( - self.disagg_prefill_bootstrap_queue.queue + self.stats.num_prefill_prealloc_queue_reqs = QueueCount.from_reqs( + self.disagg_prefill_bootstrap_queue.queue, _enable_ps ) - self.stats.num_prefill_inflight_queue_reqs = len( - self.disagg_prefill_inflight_queue + self.stats.num_prefill_inflight_queue_reqs = QueueCount.from_reqs( + self.disagg_prefill_inflight_queue, _enable_ps ) self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms @@ -288,11 +314,11 @@ def log_prefill_stats( self.stats.kv_transfer_alloc_ms = self.kv_transfer_alloc_ms self.stats.kv_transfer_total_mb = self.kv_transfer_total_mb elif self.disaggregation_mode == DisaggregationMode.DECODE: - self.stats.num_decode_prealloc_queue_reqs = len( - self.disagg_decode_prealloc_queue.queue + self.stats.num_decode_prealloc_queue_reqs = QueueCount.from_reqs( + self.disagg_decode_prealloc_queue.queue, _enable_ps ) - self.stats.num_decode_transfer_queue_reqs = len( - self.disagg_decode_transfer_queue.queue + self.stats.num_decode_transfer_queue_reqs = QueueCount.from_reqs( + self.disagg_decode_transfer_queue.queue, _enable_ps ) # Others @@ -416,8 +442,9 @@ def log_decode_stats( logger.info(msg) if self.enable_metrics: + _enable_ps = self.enable_priority_scheduling # Basics - self.stats.num_running_reqs = num_running_reqs + self.stats.num_running_reqs = QueueCount.from_reqs(batch.reqs, _enable_ps) self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch self.stats.num_used_tokens = num_used self.stats.token_usage = token_usage @@ -427,7 +454,9 @@ def log_decode_stats( self.stats.mamba_usage = mamba_usage self.stats.decode_sum_seq_lens = batch.seq_lens_cpu.sum().item() self.stats.gen_throughput = 
self.last_gen_throughput - self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_queue_reqs = QueueCount.from_reqs( + self.waiting_queue, _enable_ps + ) self.stats.num_grammar_queue_reqs = len(self.grammar_manager) self.stats.cache_hit_rate = cache_hit_rate @@ -444,20 +473,19 @@ def log_decode_stats( # PD disaggregation if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.stats.num_prefill_prealloc_queue_reqs = len( - self.disagg_prefill_bootstrap_queue.queue + self.stats.num_prefill_prealloc_queue_reqs = QueueCount.from_reqs( + self.disagg_prefill_bootstrap_queue.queue, _enable_ps ) - self.stats.num_prefill_inflight_queue_reqs = len( - self.disagg_prefill_inflight_queue + self.stats.num_prefill_inflight_queue_reqs = QueueCount.from_reqs( + self.disagg_prefill_inflight_queue, _enable_ps ) elif self.disaggregation_mode == DisaggregationMode.DECODE: - self.stats.num_decode_prealloc_queue_reqs = len( - self.disagg_decode_prealloc_queue.queue + self.stats.num_decode_prealloc_queue_reqs = QueueCount.from_reqs( + self.disagg_decode_prealloc_queue.queue, _enable_ps ) - self.stats.num_decode_transfer_queue_reqs = len( - self.disagg_decode_transfer_queue.queue + self.stats.num_decode_transfer_queue_reqs = QueueCount.from_reqs( + self.disagg_decode_transfer_queue.queue, _enable_ps ) - running_routing_keys = [r.routing_key for r in batch.reqs] waiting_routing_keys = [r.routing_key for r in self.waiting_queue] ( @@ -510,13 +538,13 @@ def _emit_kv_metrics(self: Scheduler): return kv_metrics = KvMetrics() - kv_metrics.request_active_slots = self.stats.num_running_reqs + kv_metrics.request_active_slots = self.stats.num_running_reqs.total kv_metrics.request_total_slots = self.max_running_requests kv_metrics.kv_active_blocks = int( self.stats.token_usage * self.max_total_num_tokens ) kv_metrics.kv_total_blocks = self.max_total_num_tokens - kv_metrics.num_requests_waiting = self.stats.num_queue_reqs + kv_metrics.num_requests_waiting = 
self.stats.num_queue_reqs.total kv_metrics.gpu_cache_usage_perc = self.stats.token_usage kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0 @@ -600,7 +628,7 @@ def calculate_utilization(self: Scheduler): and self.stats.max_running_requests_under_SLO > 0 ): self.stats.utilization = max( - self.stats.num_running_reqs + self.stats.num_running_reqs.total / self.stats.max_running_requests_under_SLO, self.stats.token_usage / 0.9, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 02ee4c6b7462..3da768df000b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -333,6 +333,8 @@ class ServerArgs: prefill_max_requests: Optional[int] = None schedule_policy: str = "fcfs" enable_priority_scheduling: bool = False + disable_priority_preemption: bool = False + default_priority_value: Optional[int] = None abort_on_priority_when_disabled: bool = False schedule_low_priority_values_first: bool = False priority_scheduling_preemption_threshold: int = 10 @@ -3385,6 +3387,18 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.enable_priority_scheduling, help="Enable priority scheduling. 
Requests with higher priority integer values will be scheduled first by default.", ) + parser.add_argument( + "--disable-priority-preemption", + action="store_true", + default=ServerArgs.disable_priority_preemption, + help="Disable priority scheduling preemption.", + ) + parser.add_argument( + "--default-priority-value", + type=int, + default=ServerArgs.default_priority_value, + help="Default priority for requests without explicit priority.", + ) parser.add_argument( "--abort-on-priority-when-disabled", action="store_true", diff --git a/test/registered/metrics/test_priority_metrics.py b/test/registered/metrics/test_priority_metrics.py new file mode 100644 index 000000000000..ec63f4762a80 --- /dev/null +++ b/test/registered/metrics/test_priority_metrics.py @@ -0,0 +1,174 @@ +import unittest +from typing import Dict, List + +import requests +from prometheus_client.parser import text_string_to_metric_families +from prometheus_client.samples import Sample + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=60, suite="stage-b-test-small-1-gpu") +register_amd_ci(est_time=60, suite="stage-b-test-small-1-gpu-amd") + +_MODEL_NAME = "Qwen/Qwen3-0.6B" + + +def _parse_prometheus_metrics(metrics_text: str) -> Dict[str, List[Sample]]: + result = {} + for family in text_string_to_metric_families(metrics_text): + for sample in family.samples: + if sample.name not in result: + result[sample.name] = [] + result[sample.name].append(sample) + return result + + +def _get_samples_by_name(metrics: Dict[str, List[Sample]], name: str) -> List[Sample]: + return metrics.get(name, []) + + +def _get_sample_value_by_labels(samples: List[Sample], labels: Dict[str, str]) -> float: + for sample in samples: + if all(sample.labels.get(k) == v for k, v in 
labels.items()): + return sample.value + raise KeyError(f"No sample found with labels {labels}") + + +class TestPriorityMetrics(CustomTestCase): + """Test that priority-based metrics are correctly emitted when + --enable-priority-scheduling is enabled.""" + + @classmethod + def setUpClass(cls): + cls.process = popen_launch_server( + _MODEL_NAME, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-metrics", + "--enable-priority-scheduling", + "--default-priority-value", + "0", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_priority_label_in_gauge_metrics(self): + """Send requests with different priorities and verify that + gauge metrics (num_running_reqs, num_queue_reqs) contain + the priority label dimension.""" + + # Send requests with different priorities to populate metrics + for priority in [1, 5, 10]: + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": "Hello", + "sampling_params": {"temperature": 0, "max_new_tokens": 5}, + "priority": priority, + }, + ) + self.assertEqual(response.status_code, 200) + + # Fetch metrics + metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics") + self.assertEqual(metrics_response.status_code, 200) + metrics = _parse_prometheus_metrics(metrics_response.text) + + # Verify priority label exists on queue gauge metrics + for metric_name in ["sglang:num_running_reqs", "sglang:num_queue_reqs"]: + samples = _get_samples_by_name(metrics, metric_name) + self.assertGreater(len(samples), 0, f"No samples found for {metric_name}") + + # The total sample (priority="") must always be present + # (the total has priority="" and per-priority samples have priority="<value>") + priority_labels = {s.labels.get("priority", "") for s in samples} + self.assertIn( + "", + priority_labels, + f"{metric_name}: missing total (priority='') sample", + ) + + def test_priority_label_in_histogram_metrics(self): + """Send
requests with different priorities and verify that + histogram metrics (TTFT, ITL, e2e latency) contain the priority label.""" + + for priority in [1, 5]: + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": "The capital of France is", + "sampling_params": {"temperature": 0, "max_new_tokens": 20}, + "priority": priority, + }, + ) + self.assertEqual(response.status_code, 200) + + metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics") + self.assertEqual(metrics_response.status_code, 200) + metrics = _parse_prometheus_metrics(metrics_response.text) + + # Check histogram metrics have priority label + histogram_metrics = [ + "sglang:time_to_first_token_seconds", + "sglang:e2e_request_latency_seconds", + ] + for metric_name in histogram_metrics: + # Histogram metrics are emitted as _sum, _count, _bucket + sum_name = f"{metric_name}_sum" + count_name = f"{metric_name}_count" + for suffix_name in [sum_name, count_name]: + samples = _get_samples_by_name(metrics, suffix_name) + if not samples: + continue + # At least one sample should have a non-empty priority label + priority_values = {s.labels.get("priority", "") for s in samples} + non_empty = priority_values - {""} + self.assertGreater( + len(non_empty), + 0, + f"{suffix_name}: expected per-priority samples, " + f"got priority labels: {priority_values}", + ) + + def test_default_priority_value(self): + """Requests without explicit priority should use --default-priority-value (0).""" + + # Send request WITHOUT priority — should get default priority 0 + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": "Hello world", + "sampling_params": {"temperature": 0, "max_new_tokens": 5}, + }, + ) + self.assertEqual(response.status_code, 200) + + metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics") + self.assertEqual(metrics_response.status_code, 200) + metrics = _parse_prometheus_metrics(metrics_response.text) + + # Check that e2e latency has 
samples with priority="0" (the default) + e2e_count = _get_samples_by_name( + metrics, "sglang:e2e_request_latency_seconds_count" + ) + priority_values = {s.labels.get("priority", "") for s in e2e_count} + self.assertIn( + "0", + priority_values, + f"Expected priority='0' from default, got: {priority_values}", + ) + + +if __name__ == "__main__": + unittest.main()