Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 3 additions & 19 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2354,26 +2354,10 @@ def get_num_allocatable_reqs(self, running_bs):
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
prefill_delayer_single_pass = None
if self.prefill_delayer:
# Get token usage from several pools
token_usage = None
if self.is_hybrid_swa:
_, _, full_token_usage, swa_token_usage, *_ = self._get_swa_token_info()
token_usage = max(full_token_usage, swa_token_usage)
if self.is_hybrid_ssm:
_, _, full_token_usage, mamba_token_usage, *_ = (
self._get_mamba_token_info()
)
token_usage = (
max(token_usage, mamba_token_usage)
if token_usage is not None
else max(full_token_usage, mamba_token_usage)
)
if token_usage is None:
_, token_usage, _, _ = self._get_token_info()

assert token_usage is not None
# Get max usage across all pools for prefill delay decision
max_pool_usage = self.get_pool_stats().get_max_pool_usage()
prefill_delayer_single_pass = PrefillDelayerSinglePassExecutor(
self.prefill_delayer, token_usage=token_usage
self.prefill_delayer, token_usage=max_pool_usage
)

ret = self._get_new_batch_prefill_raw(
Expand Down
230 changes: 158 additions & 72 deletions python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import dataclasses
import logging
import time
import warnings
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional, Tuple

from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.environ import envs
Expand All @@ -20,6 +21,90 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class PoolStats:
    """Usage snapshot of the scheduler's memory pools.

    The ``full_*`` fields are always required. The ``swa_*`` fields are only
    read by the helpers below when ``is_hybrid_swa`` is set, and the
    ``mamba_*`` fields only when ``is_hybrid_ssm`` is set; otherwise they keep
    their ``None`` defaults.
    """

    # Full-attention KV pool (always present).
    full_num_used: int
    full_token_usage: float
    full_available_size: int
    full_evictable_size: int

    # Flags selecting which optional groups below are consulted.
    is_hybrid_swa: bool = False
    is_hybrid_ssm: bool = False

    # Sliding-window-attention pool (hybrid-swa only).
    swa_num_used: Optional[int] = None
    swa_token_usage: Optional[float] = None
    swa_available_size: Optional[int] = None
    swa_evictable_size: Optional[int] = None

    # Mamba pool (hybrid-ssm only).
    mamba_num_used: Optional[int] = None
    mamba_usage: Optional[float] = None
    mamba_available_size: Optional[int] = None
    mamba_evictable_size: Optional[int] = None

    def get_kv_token_stats(self) -> Tuple[int, float]:
        """Return ``(num_used, token_usage)`` for the KV pools.

        With hybrid-swa the larger of the full and swa values is reported for
        each element; otherwise the full-pool values are returned as-is.
        NOTE: the mamba pool is deliberately excluded from "token usage".
        """
        if not self.is_hybrid_swa:
            return self.full_num_used, self.full_token_usage

        return (
            max(self.full_num_used, self.swa_num_used),
            max(self.full_token_usage, self.swa_token_usage),
        )

    def get_max_pool_usage(self) -> float:
        """Return the highest usage ratio among all active pools.

        The full pool always takes part; the swa and mamba pools take part
        when their respective flag is set.
        """
        candidates = [self.full_token_usage]
        if self.is_hybrid_swa:
            candidates.append(self.swa_token_usage)
        if self.is_hybrid_ssm:
            candidates.append(self.mamba_usage)
        usage = max(candidates)
        assert usage is not None and usage >= 0, f"{usage=} is not valid"
        return usage

    def get_prefill_usage_msg_parts(self) -> List[str]:
        """Build the usage fragments for a prefill log message.

        A plain (non-hybrid) setup yields a single "token usage" entry.
        Hybrid setups yield a "full token usage" entry followed by one entry
        per active extra pool (swa first, then mamba).
        """
        if not (self.is_hybrid_swa or self.is_hybrid_ssm):
            return [f"token usage: {self.full_token_usage:.2f}"]

        # Every hybrid layout starts with the full-pool figure.
        parts = [f"full token usage: {self.full_token_usage:.2f}"]
        if self.is_hybrid_swa:
            parts.append(f"swa token usage: {self.swa_token_usage:.2f}")
        if self.is_hybrid_ssm:
            parts.append(f"mamba usage: {self.mamba_usage:.2f}")
        return parts

    def get_decode_usage_msg_parts(self) -> List[str]:
        """Build the usage fragments for a decode log message.

        Same layout rules as the prefill variant, but each pool contributes
        both its absolute count and its usage ratio.
        """
        if not (self.is_hybrid_swa or self.is_hybrid_ssm):
            return [
                f"#token: {self.full_num_used}, token usage: {self.full_token_usage:.2f}"
            ]

        # Every hybrid layout starts with the full-pool figures.
        parts = [
            f"#full token: {self.full_num_used}",
            f"full token usage: {self.full_token_usage:.2f}",
        ]
        if self.is_hybrid_swa:
            parts.extend(
                (
                    f"#swa token: {self.swa_num_used}",
                    f"swa token usage: {self.swa_token_usage:.2f}",
                )
            )
        if self.is_hybrid_ssm:
            parts.extend(
                (
                    f"mamba num: {self.mamba_num_used}",
                    f"mamba usage: {self.mamba_usage:.2f}",
                )
            )
        return parts


class SchedulerRuntimeCheckerMixin:
def _session_held_tokens(self: Scheduler) -> int:
if isinstance(self.tree_cache, SessionAwareCache):
Expand All @@ -41,12 +126,36 @@ def _session_held_req_count(self: Scheduler) -> int:
return self.tree_cache.session_held_req_count()
return 0

def _get_token_info(self: Scheduler):
def get_pool_stats(self: Scheduler) -> PoolStats:
    """Collect pool statistics for whichever pool layout is active.

    Dispatches on ``is_hybrid_swa`` / ``is_hybrid_ssm``:
    neither -> ``_get_token_info()``; ssm only -> ``_get_mamba_token_info()``;
    swa (with or without ssm) -> ``_get_swa_token_info()``, with the mamba
    fields copied on top when ssm is also enabled.
    """
    if not self.is_hybrid_swa:
        if self.is_hybrid_ssm:
            return self._get_mamba_token_info()
        return self._get_token_info()

    stats = self._get_swa_token_info()

    # swa + ssm can coexist: overlay mamba fields onto swa stats
    if self.is_hybrid_ssm:
        ssm_stats = self._get_mamba_token_info()
        stats.is_hybrid_ssm = True
        for field_name in (
            "mamba_num_used",
            "mamba_usage",
            "mamba_available_size",
            "mamba_evictable_size",
        ):
            setattr(stats, field_name, getattr(ssm_stats, field_name))

    return stats

def _get_token_info(self: Scheduler) -> PoolStats:
    """Return usage statistics for the single (non-hybrid) KV token pool.

    Returns:
        A ``PoolStats`` with only the ``full_*`` fields populated.
    """
    available_size = self.token_to_kv_pool_allocator.available_size()
    evictable_size = self.tree_cache.evictable_size()
    # Anything that is neither free in the allocator nor evictable from the
    # tree cache is counted as in use.
    num_used = self.max_total_num_tokens - (available_size + evictable_size)
    token_usage = num_used / self.max_total_num_tokens
    # The stale tuple return that preceded this statement has been removed:
    # it made this return unreachable and contradicted the annotation, while
    # callers read the result via ``.full_available_size`` and similar.
    return PoolStats(
        full_num_used=num_used,
        full_token_usage=token_usage,
        full_available_size=available_size,
        full_evictable_size=evictable_size,
    )

def _get_mamba_token_info(self: Scheduler):
is_mamba_radix_cache = (
Expand All @@ -68,18 +177,20 @@ def _get_mamba_token_info(self: Scheduler):
)
full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
return (
full_num_used,
mamba_num_used,
full_token_usage,
mamba_usage,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,

return PoolStats(
is_hybrid_ssm=True,
full_num_used=full_num_used,
full_token_usage=full_token_usage,
full_available_size=full_available_size,
full_evictable_size=full_evictable_size,
mamba_num_used=mamba_num_used,
mamba_usage=mamba_usage,
mamba_available_size=mamba_available_size,
mamba_evictable_size=mamba_evictable_size,
)

def _get_swa_token_info(self: Scheduler):
def _get_swa_token_info(self: Scheduler) -> PoolStats:
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
Expand All @@ -92,28 +203,27 @@ def _get_swa_token_info(self: Scheduler):
)
full_token_usage = full_num_used / self.full_tokens_per_layer
swa_token_usage = swa_num_used / self.swa_tokens_per_layer
return (
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
full_available_size,
full_evictable_size,
swa_available_size,
swa_evictable_size,

return PoolStats(
is_hybrid_swa=True,
full_num_used=full_num_used,
full_token_usage=full_token_usage,
full_available_size=full_available_size,
full_evictable_size=full_evictable_size,
swa_num_used=swa_num_used,
swa_token_usage=swa_token_usage,
swa_available_size=swa_available_size,
swa_evictable_size=swa_evictable_size,
)

def _check_hybrid_memory(self: Scheduler):
(
full_num_used,
swa_num_used,
_,
_,
full_available_size,
full_evictable_size,
swa_available_size,
swa_evictable_size,
) = self._get_swa_token_info()
pool_stats = self._get_swa_token_info()
full_num_used = pool_stats.full_num_used
swa_num_used = pool_stats.swa_num_used
full_available_size = pool_stats.full_available_size
full_evictable_size = pool_stats.full_evictable_size
swa_available_size = pool_stats.swa_available_size
swa_evictable_size = pool_stats.swa_evictable_size
session_held_full = self._session_held_full_tokens()
session_held_swa = self._session_held_swa_tokens()

Expand All @@ -132,16 +242,13 @@ def _check_hybrid_memory(self: Scheduler):
return memory_leak, token_msg

def _check_mamba_memory(self: Scheduler):
(
full_num_used,
mamba_num_used,
_,
_,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
) = self._get_mamba_token_info()
pool_stats = self._get_mamba_token_info()
full_num_used = pool_stats.full_num_used
mamba_num_used = pool_stats.mamba_num_used
full_available_size = pool_stats.full_available_size
full_evictable_size = pool_stats.full_evictable_size
mamba_available_size = pool_stats.mamba_available_size
mamba_evictable_size = pool_stats.mamba_evictable_size
session_held = self._session_held_tokens()
memory_leak = (
full_num_used != self.tree_cache.full_protected_size() + session_held
Expand Down Expand Up @@ -181,7 +288,9 @@ def _check_mamba_memory(self: Scheduler):
return memory_leak, token_msg

def _check_radix_cache_memory(self: Scheduler):
_, _, available_size, evictable_size = self._get_token_info()
pool_stats = self._get_token_info()
available_size = pool_stats.full_available_size
evictable_size = pool_stats.full_evictable_size
protected_size = self.tree_cache.protected_size()
session_held = self._session_held_tokens()
memory_leak = (available_size + evictable_size) != (
Expand Down Expand Up @@ -219,7 +328,9 @@ def self_check_during_busy(self: Scheduler):
)
return

_, _, available_size, evictable_size = self._get_token_info()
pool_stats = self._get_token_info()
available_size = pool_stats.full_available_size
evictable_size = pool_stats.full_evictable_size
protected_size = self.tree_cache.protected_size()

uncached_size = self._get_batch_uncached_size(current_batch)
Expand Down Expand Up @@ -294,40 +405,15 @@ def check_memory(self: Scheduler):
and time.perf_counter() > self.metrics_collector.last_log_time + 30
):
# During idle time, also collect metrics every 30 seconds.
if self.is_hybrid_swa:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
elif self.is_hybrid_ssm:
(
num_used,
_,
full_token_usage,
mamba_usage,
_,
_,
_,
_,
) = self._get_mamba_token_info()
token_usage = max(full_token_usage, mamba_usage)
else:
num_used, token_usage, _, _ = self._get_token_info()
pool_stats = self.get_pool_stats()
num_used, _ = pool_stats.get_kv_token_stats()

priority_enabled = self.enable_priority_scheduling
self.stats.num_running_reqs = QueueCount.from_reqs(
self.running_batch.reqs, priority_enabled
)
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(token_usage, 2)
self.stats.token_usage = round(pool_stats.get_max_pool_usage(), 2)
self.stats.gen_throughput = 0
self.stats.num_queue_reqs = QueueCount.from_reqs(
self.waiting_queue, priority_enabled
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch

Expand Down Expand Up @@ -86,7 +86,7 @@ def get_tokens_per_layer_info(self):
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)

def get_memory_pool(self):
def get_memory_pool(self) -> Tuple[ReqToTokenPool, BaseTokenToKVPoolAllocator]:
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool_allocator,
Expand Down
Loading
Loading