diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 9c1321454d46..9c57b428f44f 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -136,7 +136,6 @@ class Envs: SGLANG_LOG_GC = EnvBool(False) SGLANG_LOG_FORWARD_ITERS = EnvBool(False) SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False) - SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE = EnvBool(True) # Test & Debug SGLANG_IS_IN_CI = EnvBool(False) @@ -159,7 +158,8 @@ class Envs: SGLANG_TEST_RETRACT = EnvBool(False) SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3) SGLANG_TEST_RETRACT_NO_PREFILL_BS = EnvInt(2 ** 31) - SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False) + SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY = EnvInt(0) + SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE = EnvBool(True) # Scheduler: new token ratio hyperparameters SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index cfbde332fcf4..899af5184915 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1004,6 +1004,9 @@ def event_loop_normal(self): self.last_batch = batch + if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get(): + self.self_check_during_busy() + @DynamicGradMode() def event_loop_overlap(self): """A scheduler loop that overlaps the CPU processing and GPU computation.""" @@ -1050,8 +1053,8 @@ def pop_and_process(): self.launch_batch_sample_if_needed(batch_result) self.last_batch = batch - if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get(): - self._check_runtime_mem_leak() + if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get(): + self.self_check_during_busy() def recv_requests( self, diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 5870a13cd17b..d566bf1ede74 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -12,6 +12,7 @@ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.utils.common import ( + ceil_align, disable_request_logging, pyspy_dump_schedulers, raise_error_or_warn, @@ -77,7 +78,23 @@ def _check_radix_cache_memory(self: Scheduler): token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n" return memory_leak, token_msg - def _check_runtime_mem_leak(self: Scheduler): + def _get_batch_uncached_size(self: Scheduler, batch: ScheduleBatch) -> int: + ret = 0 + for req in batch.reqs: + assert req.kv_committed_freed == req.kv_overallocated_freed + uncached_len = 0 + if not req.kv_committed_freed: + allocated_len = req.kv_allocated_len + if self.page_size > 1: + allocated_len = ceil_align(allocated_len, self.page_size) + assert req.cache_protected_len % self.page_size == 0 + uncached_len = allocated_len - req.cache_protected_len + + ret += uncached_len + + return ret + + def self_check_during_busy(self: Scheduler): current_batch: ScheduleBatch = self.last_batch if current_batch is None: @@ -86,45 +103,20 @@ def _check_runtime_mem_leak(self: Scheduler): _, _, available_size, evictable_size = self._get_token_info() protected_size = self.tree_cache.protected_size() - extend_size = 0 - for i, req in enumerate(current_batch.reqs): - seq_len = len(req.origin_input_ids) + len(req.output_ids) - fill_len = len(req.fill_ids) if req.fill_ids is not None else 0 - prefix_len = ( - len(req.prefix_indices) if req.prefix_indices is not None else 0 - ) - - if current_batch.forward_mode.is_decode(): - if req.finished(): - unreleased_len = 1 - else: - unreleased_len = seq_len - prefix_len - else: - unreleased_len = fill_len - prefix_len - - extend_size += unreleased_len + uncached_size = self._get_batch_uncached_size(current_batch) if ( current_batch.forward_mode.is_extend() and self.running_batch is not None and not self.running_batch.is_empty() - and self.running_batch.forward_mode.is_decode() ): - for i, req in enumerate(self.running_batch.reqs): - seq_len = len(req.origin_input_ids) + len(req.output_ids) - prefix_len = ( - len(req.prefix_indices) if req.prefix_indices is not None else 0 - ) - - if req.finished(): - unreleased_len = 0 - else: - unreleased_len = seq_len - prefix_len - 1 - - extend_size += unreleased_len + uncached_size += self._get_batch_uncached_size(self.running_batch) - total_tokens = available_size + evictable_size + protected_size + extend_size + if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get() > 1: + log_msg = f"[Mem Check (BUSY)] {available_size=}, {evictable_size=}, {protected_size=}, {uncached_size=}" + logger.info(log_msg) + total_tokens = available_size + evictable_size + protected_size + uncached_size assert ( total_tokens == self.max_total_num_tokens ), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"