Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ class Envs:
SGLANG_LOG_GC = EnvBool(False)
SGLANG_LOG_FORWARD_ITERS = EnvBool(False)
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE = EnvBool(True)

# Test & Debug
SGLANG_IS_IN_CI = EnvBool(False)
Expand All @@ -159,7 +158,8 @@ class Envs:
SGLANG_TEST_RETRACT = EnvBool(False)
SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
SGLANG_TEST_RETRACT_NO_PREFILL_BS = EnvInt(2 ** 31)
SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)
SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY = EnvInt(0)
SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE = EnvBool(True)

# Scheduler: new token ratio hyperparameters
SGLANG_INIT_NEW_TOKEN_RATIO = EnvFloat(0.7)
Expand Down
7 changes: 5 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,9 @@ def event_loop_normal(self):

self.last_batch = batch

if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get():
self.self_check_during_busy()

@DynamicGradMode()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
Expand Down Expand Up @@ -1050,8 +1053,8 @@ def pop_and_process():
self.launch_batch_sample_if_needed(batch_result)
self.last_batch = batch

if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
self._check_runtime_mem_leak()
if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get():
self.self_check_during_busy()

def recv_requests(
self,
Expand Down
56 changes: 24 additions & 32 deletions python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.utils.common import (
ceil_align,
disable_request_logging,
pyspy_dump_schedulers,
raise_error_or_warn,
Expand Down Expand Up @@ -77,7 +78,23 @@ def _check_radix_cache_memory(self: Scheduler):
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
return memory_leak, token_msg

def _check_runtime_mem_leak(self: Scheduler):
def _get_batch_uncached_size(self: Scheduler, batch: ScheduleBatch) -> int:
    """Sum, over the batch's requests, the KV tokens that are allocated
    but not protected by the radix cache (i.e. still "uncached").

    Requests whose committed KV has already been freed contribute zero.
    With paged allocation (page_size > 1) the allocated length is rounded
    up to a whole page before subtracting the cache-protected prefix.
    """
    total_uncached = 0
    for request in batch.reqs:
        # Committed and over-allocated KV are expected to be freed together.
        assert request.kv_committed_freed == request.kv_overallocated_freed
        if request.kv_committed_freed:
            # Fully released request: nothing left to account for.
            continue
        span = request.kv_allocated_len
        if self.page_size > 1:
            # Round up to page granularity to match the allocator's view.
            span = ceil_align(span, self.page_size)
        # The protected prefix must itself be page-aligned.
        assert request.cache_protected_len % self.page_size == 0
        total_uncached += span - request.cache_protected_len

    return total_uncached

def self_check_during_busy(self: Scheduler):
current_batch: ScheduleBatch = self.last_batch

if current_batch is None:
Expand All @@ -86,45 +103,20 @@ def _check_runtime_mem_leak(self: Scheduler):
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()

extend_size = 0
for i, req in enumerate(current_batch.reqs):
seq_len = len(req.origin_input_ids) + len(req.output_ids)
fill_len = len(req.fill_ids) if req.fill_ids is not None else 0
prefix_len = (
len(req.prefix_indices) if req.prefix_indices is not None else 0
)

if current_batch.forward_mode.is_decode():
if req.finished():
unreleased_len = 1
else:
unreleased_len = seq_len - prefix_len
else:
unreleased_len = fill_len - prefix_len

extend_size += unreleased_len
uncached_size = self._get_batch_uncached_size(current_batch)

if (
current_batch.forward_mode.is_extend()
and self.running_batch is not None
and not self.running_batch.is_empty()
and self.running_batch.forward_mode.is_decode()
):
for i, req in enumerate(self.running_batch.reqs):
seq_len = len(req.origin_input_ids) + len(req.output_ids)
prefix_len = (
len(req.prefix_indices) if req.prefix_indices is not None else 0
)

if req.finished():
unreleased_len = 0
else:
unreleased_len = seq_len - prefix_len - 1

extend_size += unreleased_len
uncached_size += self._get_batch_uncached_size(self.running_batch)

total_tokens = available_size + evictable_size + protected_size + extend_size
if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get() > 1:
log_msg = f"[Mem Check (BUSY)] {available_size=}, {evictable_size=}, {protected_size=}, {uncached_size=}"
logger.info(log_msg)

total_tokens = available_size + evictable_size + protected_size + uncached_size
assert (
total_tokens == self.max_total_num_tokens
), f"Mem Leak Detected! {total_tokens=} vs {self.max_total_num_tokens=}"
Expand Down
Loading