diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index 5e1a909c58f0..e3ee0ae1c3ae 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -1191,7 +1191,7 @@ def event_loop_normal_disagg_decode(self: Scheduler):
                 self.process_batch_result(batch, result)
             else:
                 # When the server is idle, do self-check and re-init some states
-                self.self_check_during_idle()
+                self.on_idle()
 
             # Update last_batch
             self.last_batch = batch
@@ -1224,7 +1224,7 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
                 tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                self.self_check_during_idle()
+                self.on_idle()
 
             # Run sample of the current batch
             # It depends on the result of the last batch (e.g., grammar), so we run it after the last batch is processed.
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index 8e3da245b9e8..0e0a653fdaa0 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -409,7 +409,7 @@ def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                self.self_check_during_idle()
+                self.on_idle()
 
             self.process_disagg_prefill_inflight_queue()
 
@@ -448,7 +448,7 @@ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
-                self.self_check_during_idle()
+                self.on_idle()
 
             self.process_disagg_prefill_inflight_queue()
 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1ea27ba24ecd..84fa34da8e3d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1370,7 +1370,7 @@ def event_loop_normal(self):
                 self.process_batch_result(batch, result)
             else:
                 # When the server is idle, do self-check and re-init some states.
-                self.self_check_during_idle()
+                self.on_idle()
 
             # Update last_batch
             self.last_batch = batch
@@ -1420,7 +1420,7 @@ def pop_and_process():
                     pop_and_process()
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
-                self.self_check_during_idle()
+                self.on_idle()
 
             # Run sample of the current batch
             # It depends on the result of the last batch (e.g., grammar), so we run it after the last batch is processed.
diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py
index ba9cc0ac2342..9c0edc315173 100644
--- a/python/sglang/srt/managers/scheduler_pp_mixin.py
+++ b/python/sglang/srt/managers/scheduler_pp_mixin.py
@@ -142,7 +142,7 @@ def event_loop_pp(self: Scheduler):
 
             # When the server is idle, self-check and re-init some states
             if server_is_idle:
-                self.self_check_during_idle()
+                self.on_idle()
 
     @DynamicGradMode()
     def event_loop_pp_disagg_prefill(self: Scheduler):
@@ -318,7 +318,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler):
 
             # When the server is idle, self-check and re-init some states
             if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
-                self.self_check_during_idle()
+                self.on_idle()
 
     @DynamicGradMode()
     def event_loop_pp_disagg_decode(self: Scheduler):
@@ -508,7 +508,7 @@ def event_loop_pp_disagg_decode(self: Scheduler):
                 queue_size += len(self.decode_offload_manager.ongoing_offload)
 
             if server_is_idle and queue_size == 0:
-                self.self_check_during_idle()
+                self.on_idle()
 
     def init_pp_loop_state(self: Scheduler):
         self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth
diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
index 8fdd4391e10e..b5a0e4a3086c 100644
--- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
+++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
@@ -228,45 +228,74 @@ def _get_swa_token_info(self: Scheduler) -> PoolStats:
             swa_evictable_size=swa_evictable_size,
         )
 
-    def _check_hybrid_memory(self: Scheduler):
-        pool_stats = self._get_swa_token_info()
-        full_num_used = pool_stats.full_num_used
-        swa_num_used = pool_stats.swa_num_used
-        full_available_size = pool_stats.full_available_size
-        full_evictable_size = pool_stats.full_evictable_size
-        swa_available_size = pool_stats.swa_available_size
-        swa_evictable_size = pool_stats.swa_evictable_size
-        session_held_full = self._session_held_full_tokens()
-        session_held_swa = self._session_held_swa_tokens()
-
-        # Streaming sessions hold tree locks during idle, so tree-protected
-        # tokens must be accounted for alongside session-held tokens.
-        full_protected = self.tree_cache.full_protected_size()
-        swa_protected = self.tree_cache.swa_protected_size()
-        full_leaked = full_num_used - full_protected - session_held_full
-        swa_leaked = swa_num_used - swa_protected - session_held_swa
-        memory_leak = full_leaked != 0 or swa_leaked != 0
-        token_msg = (
-            f"{full_leaked=}, {swa_leaked=}\n"
-            f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {full_protected=}, {session_held_full=}\n"
-            f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {swa_protected=}, {session_held_swa=}\n"
+    @staticmethod
+    def _check_pool_invariant(
+        pool_name: str,
+        available: int,
+        evictable: int,
+        protected: int,
+        session_held: int,
+        total: int,
+        uncached: int = 0,
+    ) -> Tuple[bool, str]:
+        """Check: available + evictable + protected + session_held + uncached == total."""
+        total_accounted = available + evictable + protected + session_held + uncached
+        leak = total_accounted != total
+        msg = (
+            f"[{pool_name}] {total=}, {available=}, {evictable=}, "
+            f"{protected=}, {session_held=}, {uncached=}"
         )
-        return memory_leak, token_msg
-
-    def _check_mamba_memory(self: Scheduler):
-        pool_stats = self._get_mamba_token_info()
-        full_num_used = pool_stats.full_num_used
-        mamba_num_used = pool_stats.mamba_num_used
-        full_available_size = pool_stats.full_available_size
-        full_evictable_size = pool_stats.full_evictable_size
-        mamba_available_size = pool_stats.mamba_available_size
-        mamba_evictable_size = pool_stats.mamba_evictable_size
-        session_held = self._session_held_tokens()
-        memory_leak = (
-            full_num_used != self.tree_cache.full_protected_size() + session_held
-            or mamba_num_used != self.tree_cache.mamba_protected_size()
+        return leak, msg
+
+    def _check_full_pool(
+        self: Scheduler, ps: PoolStats, uncached: int = 0
+    ) -> Tuple[bool, str]:
+        if self.is_hybrid_swa:
+            protected = self.tree_cache.full_protected_size()
+            session_held = self._session_held_full_tokens()
+            total = self.full_tokens_per_layer
+        elif self.is_hybrid_ssm and self.tree_cache.supports_mamba():
+            protected = self.tree_cache.full_protected_size()
+            session_held = self._session_held_tokens()
+            total = self.token_to_kv_pool_allocator.size
+        else:
+            protected = self.tree_cache.protected_size()
+            session_held = self._session_held_tokens()
+            total = self.max_total_num_tokens
+        return self._check_pool_invariant(
+            "full",
+            ps.full_available_size,
+            ps.full_evictable_size,
+            protected,
+            session_held,
+            total,
+            uncached,
+        )
+
+    def _check_swa_pool(
+        self: Scheduler, ps: PoolStats, uncached: int = 0
+    ) -> Tuple[bool, str]:
+        return self._check_pool_invariant(
+            "swa",
+            ps.swa_available_size,
+            ps.swa_evictable_size,
+            self.tree_cache.swa_protected_size(),
+            self._session_held_swa_tokens(),
+            self.swa_tokens_per_layer,
+            uncached,
+        )
+
+    def _check_mamba_pool(self: Scheduler, ps: PoolStats) -> Tuple[bool, str]:
+        leak, msg = self._check_pool_invariant(
+            "mamba",
+            ps.mamba_available_size,
+            ps.mamba_evictable_size,
+            self.tree_cache.mamba_protected_size(),
+            0,
+            self.req_to_token_pool.mamba_pool.size,
         )
-        if memory_leak:
+        if leak:
+            # Page-level leak diagnosis for mamba
             free_full_pages = set(
                 self.token_to_kv_pool_allocator.free_pages.tolist()
                 + self.token_to_kv_pool_allocator.release_pages.tolist()
@@ -288,28 +317,11 @@ def _check_mamba_memory(self: Scheduler):
             leaked_mamba_pages = (
                 expected_mamba_pages - free_mamba_pages - cached_mamba_pages
             )
-            token_msg = (
-                f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
-                f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}, leaked_full_pages={leaked_full_pages if len(leaked_full_pages) > 0 else None}, leaked_mamba_pages={leaked_mamba_pages if len(leaked_mamba_pages) > 0 else None}\n"
+            msg += (
+                f", leaked_full_pages={leaked_full_pages or None}"
+                f", leaked_mamba_pages={leaked_mamba_pages or None}"
             )
-        else:
-            token_msg = (
-                f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
-                f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
-            )
-        return memory_leak, token_msg
-
-    def _check_radix_cache_memory(self: Scheduler):
-        pool_stats = self._get_token_info()
-        available_size = pool_stats.full_available_size
-        evictable_size = pool_stats.full_evictable_size
-        protected_size = self.tree_cache.protected_size()
-        session_held = self._session_held_tokens()
-        memory_leak = (available_size + evictable_size) != (
-            self.max_total_num_tokens - protected_size - session_held
-        )
-        token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}, {session_held=}\n"
-        return memory_leak, token_msg
+        return leak, msg
 
     def _get_batch_uncached_size(self: Scheduler, batch: ScheduleBatch) -> int:
         ret = 0
@@ -327,10 +339,20 @@ def _get_batch_uncached_size(self: Scheduler, batch: ScheduleBatch) -> int:
 
         return ret
 
-    def self_check_during_busy(self: Scheduler):
+    def _get_total_uncached_size(self: Scheduler) -> int:
+        """Sum uncached tokens across the current and running batches."""
         current_batch: ScheduleBatch = self.last_batch
+        uncached_size = self._get_batch_uncached_size(current_batch)
+        if (
+            current_batch.forward_mode.is_extend()
+            and self.running_batch is not None
+            and not self.running_batch.is_empty()
+        ):
+            uncached_size += self._get_batch_uncached_size(self.running_batch)
+        return uncached_size
 
-        if current_batch is None:
+    def self_check_during_busy(self: Scheduler):
+        if self.last_batch is None:
             return
 
         spec_topk = self.server_args.speculative_eagle_topk or 1
@@ -340,35 +362,12 @@ def self_check_during_busy(self: Scheduler):
             )
             return
 
-        pool_stats = self._get_token_info()
-        available_size = pool_stats.full_available_size
-        evictable_size = pool_stats.full_evictable_size
-        protected_size = self.tree_cache.protected_size()
-
-        uncached_size = self._get_batch_uncached_size(current_batch)
-
-        if (
-            current_batch.forward_mode.is_extend()
-            and self.running_batch is not None
-            and not self.running_batch.is_empty()
-        ):
-            uncached_size += self._get_batch_uncached_size(self.running_batch)
+        uncached = self._get_total_uncached_size()
+        leak, msg = self._check_full_pool(self.get_pool_stats(), uncached=uncached)
 
         if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get() > 1:
-            log_msg = f"[Mem Check (BUSY)] {available_size=}, {evictable_size=}, {protected_size=}, {uncached_size=}"
-            logger.info(log_msg)
-
-        session_held = self._session_held_tokens()
-        total_tokens = (
-            available_size
-            + evictable_size
-            + protected_size
-            + uncached_size
-            + session_held
-        )
-        assert (
-            total_tokens == self.max_total_num_tokens
-        ), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
+            logger.info(f"[Mem Check (BUSY)] {msg}")
+        assert not leak, f"Mem Leak Detected! {msg}"
 
     def _check_req_pool(self: Scheduler):
         if self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -393,59 +392,74 @@ def _check_req_pool(self: Scheduler):
                 msg,
             )
 
-    def check_memory(self: Scheduler):
+    def _report_leak(self: Scheduler, pool_name: str, token_msg: str):
+        msg = f"{pool_name} memory leak detected! {token_msg}"
+        raise_error_or_warn(
+            self,
+            envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE.get(),
+            "count_memory_leak_warnings",
+            msg,
+        )
+
+    def _check_all_pools(
+        self: Scheduler, ps: PoolStats, uncached: int = 0
+    ) -> Tuple[bool, List[str]]:
+        """Check memory invariant across all pools. Returns (has_leak, messages)."""
+        has_leak = False
+        messages = []
+
+        full_leak, full_msg = self._check_full_pool(ps, uncached=uncached)
+        has_leak |= full_leak
+        messages.append(full_msg)
+
         if self.is_hybrid_swa:
-            memory_leak, token_msg = self._check_hybrid_memory()
-        elif self.is_hybrid_ssm and self.tree_cache.supports_mamba():
-            memory_leak, token_msg = self._check_mamba_memory()
-        else:
-            memory_leak, token_msg = self._check_radix_cache_memory()
+            swa_leak, swa_msg = self._check_swa_pool(ps)
+            has_leak |= swa_leak
+            messages.append(swa_msg)
 
-        if memory_leak:
-            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
-            raise_error_or_warn(
-                self,
-                envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE.get(),
-                "count_memory_leak_warnings",
-                msg,
-            )
+        if self.is_hybrid_ssm and self.tree_cache.supports_mamba():
+            mamba_leak, mamba_msg = self._check_mamba_pool(ps)
+            has_leak |= mamba_leak
+            messages.append(mamba_msg)
 
-        self._check_req_pool()
+        return has_leak, messages
 
+    def _maybe_log_idle_metrics(self: Scheduler):
+        """Collect and log metrics every 30 seconds during idle."""
         if (
-            self.current_scheduler_metrics_enabled
-            and time.perf_counter() > self.metrics_collector.last_log_time + 30
+            not self.current_scheduler_metrics_enabled
+            or time.perf_counter() <= self.metrics_collector.last_log_time + 30
         ):
-            # During idle time, also collect metrics every 30 seconds.
-            self.get_pool_stats().update_scheduler_stats(self.stats)
+            return
 
-            priority_enabled = self.enable_priority_scheduling
-            self.stats.num_running_reqs = QueueCount.from_reqs(
-                self.running_batch.reqs, priority_enabled
+        self.get_pool_stats().update_scheduler_stats(self.stats)
+
+        priority_enabled = self.enable_priority_scheduling
+        self.stats.num_running_reqs = QueueCount.from_reqs(
+            self.running_batch.reqs, priority_enabled
+        )
+        self.stats.gen_throughput = 0
+        self.stats.num_queue_reqs = QueueCount.from_reqs(
+            self.waiting_queue, priority_enabled
+        )
+        self.stats.num_grammar_queue_reqs = len(self.grammar_manager)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.num_prefill_prealloc_queue_reqs = QueueCount.from_reqs(
+                self.disagg_prefill_bootstrap_queue.queue, priority_enabled
             )
-            self.stats.gen_throughput = 0
-            self.stats.num_queue_reqs = QueueCount.from_reqs(
-                self.waiting_queue, priority_enabled
+            self.stats.num_prefill_inflight_queue_reqs = QueueCount.from_reqs(
+                self.disagg_prefill_inflight_queue, priority_enabled
             )
-            self.stats.num_grammar_queue_reqs = len(self.grammar_manager)
-            if self.disaggregation_mode == DisaggregationMode.PREFILL:
-                self.stats.num_prefill_prealloc_queue_reqs = QueueCount.from_reqs(
-                    self.disagg_prefill_bootstrap_queue.queue, priority_enabled
-                )
-                self.stats.num_prefill_inflight_queue_reqs = QueueCount.from_reqs(
-                    self.disagg_prefill_inflight_queue, priority_enabled
-                )
-            if self.disaggregation_mode == DisaggregationMode.DECODE:
-                self.stats.num_decode_prealloc_queue_reqs = QueueCount.from_reqs(
-                    self.disagg_decode_prealloc_queue.queue, priority_enabled
-                )
-                self.stats.num_decode_transfer_queue_reqs = QueueCount.from_reqs(
-                    self.disagg_decode_transfer_queue.queue, priority_enabled
-                )
-            self.metrics_collector.log_stats(self.stats)
-        self._publish_kv_events()
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.stats.num_decode_prealloc_queue_reqs = QueueCount.from_reqs(
+                self.disagg_decode_prealloc_queue.queue, priority_enabled
+            )
+            self.stats.num_decode_transfer_queue_reqs = QueueCount.from_reqs(
+                self.disagg_decode_transfer_queue.queue, priority_enabled
+            )
+        self.metrics_collector.log_stats(self.stats)
 
-    def check_tree_cache(self: Scheduler):
+    def _check_tree_cache(self: Scheduler):
         if (
             self.tree_cache.is_tree_cache()
             and (self.is_hybrid_swa and self.tree_cache.supports_swa())
@@ -453,26 +467,30 @@ def check_tree_cache(self: Scheduler):
         ):
             self.tree_cache.sanity_check()
 
-    def self_check_during_idle(self: Scheduler):
-        if self.enable_hisparse and self.hisparse_coordinator.has_ongoing_staging():
+    def on_idle(self: Scheduler):
+        """Idle housekeeping: guard, check, metrics, reset, sleep."""
+        if not self.is_fully_idle():
             return
 
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            if len(self.disagg_prefill_inflight_queue) > 0:
-                return
-        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            queue_size = (
-                len(self.waiting_queue)
-                + len(self.disagg_decode_transfer_queue.queue)
-                + len(self.disagg_decode_prealloc_queue.queue)
-            )
-            if self.server_args.disaggregation_decode_enable_offload_kvcache:
-                queue_size += len(self.decode_offload_manager.ongoing_offload)
-            if queue_size:
-                return
-        self.check_memory()
-        self.check_tree_cache()
+        # memory leak check
+        has_leak, messages = self._check_all_pools(self.get_pool_stats())
+        if has_leak:
+            self._report_leak("pool", "\n".join(messages))
+        self._check_req_pool()
+
+        # tree cache sanity check
+        self._check_tree_cache()
+
+        # metrics every 30s
+        self._maybe_log_idle_metrics()
+
+        # kv event publishing
+        self._publish_kv_events()
+
+        # reset token ratio
         self.new_token_ratio = self.init_new_token_ratio
+
+        # sleep until next event
         self.maybe_sleep_on_idle()
 
 
@@ -482,16 +500,10 @@ def create_scheduler_watchdog(
     def dump_info() -> str:
         if scheduler.is_initializing or disable_request_logging():
             return ""
-        if scheduler.is_hybrid_swa:
-            _, info_msg = scheduler._check_hybrid_memory()
-        elif scheduler.is_hybrid_ssm and scheduler.tree_cache.supports_mamba():
-            _, info_msg = scheduler._check_mamba_memory()
-        else:
-            _, info_msg = scheduler._check_radix_cache_memory()
+        _, messages = scheduler._check_all_pools(scheduler.get_pool_stats())
         return (
             f"{scheduler.cur_batch.batch_size()=}\n"
-            f"{scheduler.cur_batch.reqs=}\n"
-            f"{info_msg}"
+            f"{scheduler.cur_batch.reqs=}\n" + "\n".join(messages)
         )
 
     return WatchdogRaw(
diff --git a/python/sglang/srt/multiplex/multiplexing_mixin.py b/python/sglang/srt/multiplex/multiplexing_mixin.py
index 1e1e858aefb6..9902afe5c16f 100644
--- a/python/sglang/srt/multiplex/multiplexing_mixin.py
+++ b/python/sglang/srt/multiplex/multiplexing_mixin.py
@@ -128,10 +128,7 @@ def event_loop_pdmux(self: Scheduler):
                     stream_idx > 0 and self.running_batch.is_empty()
                 )
                 if self.running_batch.is_empty() and self.split_prefill_batch is None:
-                    self.check_memory()
-                    self.check_tree_cache()
-                    self.new_token_ratio = self.init_new_token_ratio
-                    self.maybe_sleep_on_idle()
+                    self.on_idle()
 
                 if adjust_stream_group:
                     prefill_stream.synchronize()