2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tp_worker.py
@@ -54,7 +54,7 @@
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.model_executor.model_runner_kv_cache_mixin import MemoryPoolConfig
+    from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig
 
 logger = logging.getLogger(__name__)
 
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -136,12 +136,12 @@
 )
 from sglang.srt.model_executor.hook_manager import register_forward_hooks
 from sglang.srt.model_executor.model_runner_kv_cache_mixin import (
-    MemoryPoolConfig,
     ModelRunnerKVCacheMixin,
 )
 from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
     PiecewiseCudaGraphRunner,
 )
+from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
     RemoteInstanceWeightLoaderBackend,
125 changes: 27 additions & 98 deletions python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
 import logging
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING
 
 import torch
 
@@ -39,25 +38,7 @@
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
-
-
-@dataclass
-class MemoryPoolConfig:
-    """Resolved memory pool config, shared between target and draft workers."""
-
-    max_total_num_tokens: int
-    max_running_requests: int
-    full_max_total_num_tokens: Optional[int] = None
-    swa_max_total_num_tokens: Optional[int] = None
-
-    mem_fraction_static: Optional[float] = None
-
-    def __post_init__(self):
-        if self.max_total_num_tokens <= 0:
-            msg = "Not enough memory. Please try to increase --mem-fraction-static."
-            if self.mem_fraction_static is not None:
-                msg += f" Current value: mem_fraction_static={self.mem_fraction_static}"
-            raise RuntimeError(msg)
+    from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig
 
 
 # the ratio of mamba cache pool size to max_running_requests
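
The dataclass itself moves to pool_configurator.py. A sketch of its relocated form, reconstructed from the lines deleted above; the only inferred change is a default for max_running_requests, since the new call sites in _resolve_memory_pool_config assign it after construction:

# Sketch of the relocated dataclass in pool_configurator.py, reconstructed from
# the deleted definition above. The default on max_running_requests is inferred
# from the new call sites, which set it after construction; the real module
# may differ.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MemoryPoolConfig:
    """Resolved memory pool config, shared between target and draft workers."""

    max_total_num_tokens: int
    max_running_requests: int = 0  # inferred default; set later by the model runner
    full_max_total_num_tokens: Optional[int] = None
    swa_max_total_num_tokens: Optional[int] = None
    mem_fraction_static: Optional[float] = None

    def __post_init__(self):
        # Profiling left no room for the KV cache: fail fast with a hint.
        if self.max_total_num_tokens <= 0:
            msg = "Not enough memory. Please try to increase --mem-fraction-static."
            if self.mem_fraction_static is not None:
                msg += f" Current value: mem_fraction_static={self.mem_fraction_static}"
            raise RuntimeError(msg)
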
@@ -73,9 +54,7 @@ def __post_init__(self):
 
 class ModelRunnerKVCacheMixin:
 
-    def _profile_available_bytes(
-        self: ModelRunner, pre_model_load_memory: int
-    ) -> float:
+    def _profile_available_bytes(self: ModelRunner, pre_model_load_memory: int) -> int:
         post_model_load_memory = get_available_gpu_memory(
             self.device,
             self.gpu_id,
@@ -89,48 +68,7 @@ def _profile_available_bytes(
         if self.mambaish_config is not None:
             rest_memory = self.handle_max_mamba_cache(rest_memory)
 
-        return rest_memory * (1 << 30)  # return in bytes
-
-    def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int):
-        # Get the number of layers used for KV cache calculation
-        if self.is_draft_worker:
-            num_layers = getattr(
-                self.model_config.hf_config,
-                "num_nextn_predict_layers",
-                self.num_effective_layers,
-            )
-        elif mambaish := self.mambaish_config:
-            effective_layer_ids = [
-                i
-                for i in mambaish.full_attention_layer_ids
-                if self.start_layer <= i < self.end_layer
-            ]
-            num_layers = len(effective_layer_ids)
-        else:
-            num_layers = self.num_effective_layers
-
-        from sglang.srt.model_executor.pool_configurator import get_cell_size_per_token
-
-        cell_size = get_cell_size_per_token(self, num_layers)
-        if self.spec_algorithm.is_dflash() and not self.is_draft_worker:
-            from sglang.srt.speculative.dflash_utils import (
-                scale_kv_cell_size_per_token_for_dflash,
-            )
-
-            draft_num_layers = getattr(self, "dflash_draft_num_layers", None)
-            if (
-                draft_num_layers is not None
-                and int(draft_num_layers) > 0
-                and int(num_layers) > 0
-            ):
-                cell_size = scale_kv_cell_size_per_token_for_dflash(
-                    target_cell_size_per_token=cell_size,
-                    target_num_layers=int(num_layers),
-                    draft_num_layers=int(draft_num_layers),
-                )
-
-        available_bytes = self._profile_available_bytes(pre_model_load_memory)
-        return int(available_bytes) // cell_size
+        return int(rest_memory * (1 << 30))  # return in bytes
 
     def handle_max_mamba_cache(self: ModelRunner, total_rest_memory):
         config = self.mambaish_config
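
profile_max_num_token is gone from the mixin; its core arithmetic (divide profiled bytes by the per-token KV cell size) presumably now sits behind calculate_pool_sizes in the configurator. A minimal sketch of that math, using names from the deleted method; pairing it with page alignment is an assumption based on the docstring change further down:

# Minimal sketch of the capacity math that moved out of the mixin, using names
# from the deleted profile_max_num_token above; not the configurator's actual code.
def profile_max_num_token(available_bytes: int, cell_size_per_token: int) -> int:
    # cell_size_per_token would come from get_cell_size_per_token(runner, num_layers)
    # in pool_configurator (see the deleted import above).
    return int(available_bytes) // cell_size_per_token


def align_to_page(token_capacity: int, page_size: int) -> int:
    # Page alignment also moved into the configurator (see the docstring
    # change in _apply_token_constraints below).
    return token_capacity // page_size * page_size

For example, 40 GiB of headroom at a 160 KiB cell size yields 262,144 tokens, which a page size of 64 leaves unchanged.
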
@@ -240,20 +178,6 @@ def calculate_mla_kv_cache_dim(self: ModelRunner) -> int:
 
         return kv_cache_dim
 
-    def _resolve_hybrid_swa_tokens(
-        self: ModelRunner, token_capacity: int
-    ) -> Tuple[int, int, int]:
-        """Split token_capacity into full/swa pools.
-
-        Returns (effective_capacity, full_max_total_num_tokens, swa_max_total_num_tokens).
-        """
-        from sglang.srt.model_executor.pool_configurator import (
-            resolve_hybrid_swa_tokens,
-        )
-
-        assert self.sliding_window_size is not None and self.sliding_window_size > 0
-        return resolve_hybrid_swa_tokens(self, token_capacity)
-
     def _calculate_mamba_ratio(self: ModelRunner) -> int:
         if self.server_args.disable_radix_cache:
             return 1
@@ -690,7 +614,11 @@ def _init_pools(self: ModelRunner):
         )
 
     def _apply_token_constraints(self: ModelRunner, token_capacity: int) -> int:
-        """Apply external constraints to token capacity: user cap, page alignment, PP sync."""
+        """Apply external constraints to token capacity: user cap, PP sync.
+
+        Page alignment is handled by the configurator, not here.
+        If constraints change the value, the configurator re-runs and re-aligns.
+        """
         user_limit = self.server_args.max_total_tokens
 
         # Apply user-specified upper bound
@@ -702,10 +630,6 @@ def _apply_token_constraints(self: ModelRunner, token_capacity: int) -> int:
             )
             token_capacity = min(token_capacity, user_limit)
 
-        # Align to page boundary
-        page_size = self.server_args.page_size
-        token_capacity = token_capacity // page_size * page_size
-
         # Sync across PP ranks (each may have different layer counts)
         if self.pp_size > 1:
             tensor = torch.tensor(token_capacity, dtype=torch.int64)
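
The PP sync itself is cut off by the diff; a plausible completion, assuming a min-reduction over torch.distributed (the actual process group and reduction op are not visible here):

# Hypothetical completion of the PP-rank sync above: each pipeline rank can
# profile a different capacity, so take the minimum across ranks. The real
# code's process group and reduction op are not shown in this diff.
import torch
import torch.distributed as dist


def sync_across_pp_ranks(token_capacity: int) -> int:
    tensor = torch.tensor(token_capacity, dtype=torch.int64)
    dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
    return int(tensor.item())
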
@@ -753,23 +677,28 @@ def _resolve_memory_pool_config(
         self: ModelRunner, pre_model_load_memory: int
     ) -> MemoryPoolConfig:
         """Profile GPU memory and resolve all pool parameters into a config."""
-        profiled_tokens = self.profile_max_num_token(pre_model_load_memory)
-        token_capacity = self._apply_token_constraints(profiled_tokens)
-
-        full_tokens = None
-        swa_tokens = None
-        if self.is_hybrid_swa:
-            token_capacity, full_tokens, swa_tokens = self._resolve_hybrid_swa_tokens(
-                token_capacity
-            )
-
-        return MemoryPoolConfig(
-            max_total_num_tokens=token_capacity,
-            max_running_requests=self._resolve_max_num_reqs(token_capacity),
-            full_max_total_num_tokens=full_tokens,
-            swa_max_total_num_tokens=swa_tokens,
-            mem_fraction_static=self.server_args.mem_fraction_static,
-        )
+        from sglang.srt.model_executor.pool_configurator import (
+            create_memory_pool_configurator,
+        )
+
+        available_bytes = self._profile_available_bytes(pre_model_load_memory)
+        page_size = self.server_args.page_size
+
+        configurator = create_memory_pool_configurator(self)
+        config = configurator.calculate_pool_sizes(available_bytes, page_size)
+
+        # Apply external constraints (user cap, PP sync)
+        constrained = self._apply_token_constraints(config.max_total_num_tokens)
+        if constrained != config.max_total_num_tokens:
+            config = configurator.calculate_pool_sizes_from_max_tokens(
+                constrained, page_size
+            )
+
+        config.max_running_requests = self._resolve_max_num_reqs(
+            config.max_total_num_tokens
+        )
+        config.mem_fraction_static = self.server_args.mem_fraction_static
+        return config
 
     def init_memory_pool(self: ModelRunner, pre_model_load_memory: int):
         if not self.spec_algorithm.is_none() and self.is_draft_worker:
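
The rewritten method leans on a configurator object that this diff only shows through its call sites. A minimal sketch of the interface those calls imply; everything beyond the three names visible above (create_memory_pool_configurator, calculate_pool_sizes, calculate_pool_sizes_from_max_tokens) is assumption:

# Hypothetical shape of pool_configurator implied by the call sites above;
# the real module may be organized differently. MemoryPoolConfig is the
# relocated dataclass sketched earlier.
from __future__ import annotations


class MemoryPoolConfigurator:
    def __init__(self, model_runner) -> None:
        self.model_runner = model_runner

    def calculate_pool_sizes(
        self, available_bytes: int, page_size: int
    ) -> MemoryPoolConfig:
        # Turn profiled bytes into a page-aligned token capacity and, for
        # hybrid-SWA models, split it into full/SWA pools.
        raise NotImplementedError

    def calculate_pool_sizes_from_max_tokens(
        self, max_tokens: int, page_size: int
    ) -> MemoryPoolConfig:
        # Re-derive the same split from an externally constrained capacity,
        # re-aligning to the page size as the new docstring promises.
        raise NotImplementedError


def create_memory_pool_configurator(model_runner) -> MemoryPoolConfigurator:
    return MemoryPoolConfigurator(model_runner)
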