From d235d6df4ac6e28f679cc6eaa1d7389855822600 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 16:08:42 -0700 Subject: [PATCH 01/12] introduce MemoryPoolConfigurator class hierarchy --- .../model_runner_kv_cache_mixin.py | 89 ++--- .../srt/model_executor/pool_configurator.py | 360 +++++++++++------- 2 files changed, 252 insertions(+), 197 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index b701524fd1c4..adb3f5d1591a 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional import torch @@ -91,47 +91,6 @@ def _profile_available_bytes( return rest_memory * (1 << 30) # return in bytes - def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int): - # Get the number of layers used for KV cache calculation - if self.is_draft_worker: - num_layers = getattr( - self.model_config.hf_config, - "num_nextn_predict_layers", - self.num_effective_layers, - ) - elif mambaish := self.mambaish_config: - effective_layer_ids = [ - i - for i in mambaish.full_attention_layer_ids - if self.start_layer <= i < self.end_layer - ] - num_layers = len(effective_layer_ids) - else: - num_layers = self.num_effective_layers - - from sglang.srt.model_executor.pool_configurator import get_cell_size_per_token - - cell_size = get_cell_size_per_token(self, num_layers) - if self.spec_algorithm.is_dflash() and not self.is_draft_worker: - from sglang.srt.speculative.dflash_utils import ( - scale_kv_cell_size_per_token_for_dflash, - ) - - draft_num_layers = getattr(self, "dflash_draft_num_layers", None) - if ( - draft_num_layers is not None - and int(draft_num_layers) > 0 - and int(num_layers) > 0 - ): - cell_size = scale_kv_cell_size_per_token_for_dflash( - target_cell_size_per_token=cell_size, - target_num_layers=int(num_layers), - draft_num_layers=int(draft_num_layers), - ) - - available_bytes = self._profile_available_bytes(pre_model_load_memory) - return int(available_bytes) // cell_size - def handle_max_mamba_cache(self: ModelRunner, total_rest_memory): config = self.mambaish_config server_args = self.server_args @@ -240,20 +199,6 @@ def calculate_mla_kv_cache_dim(self: ModelRunner) -> int: return kv_cache_dim - def _resolve_hybrid_swa_tokens( - self: ModelRunner, token_capacity: int - ) -> Tuple[int, int, int]: - """Split token_capacity into full/swa pools. - - Returns (effective_capacity, full_max_total_num_tokens, swa_max_total_num_tokens). - """ - from sglang.srt.model_executor.pool_configurator import ( - resolve_hybrid_swa_tokens, - ) - - assert self.sliding_window_size is not None and self.sliding_window_size > 0 - return resolve_hybrid_swa_tokens(self, token_capacity) - def _calculate_mamba_ratio(self: ModelRunner) -> int: if self.server_args.disable_radix_cache: return 1 @@ -753,19 +698,31 @@ def _resolve_memory_pool_config( self: ModelRunner, pre_model_load_memory: int ) -> MemoryPoolConfig: """Profile GPU memory and resolve all pool parameters into a config.""" - profiled_tokens = self.profile_max_num_token(pre_model_load_memory) - token_capacity = self._apply_token_constraints(profiled_tokens) + from sglang.srt.model_executor.pool_configurator import ( + create_memory_pool_configurator, + ) - full_tokens = None - swa_tokens = None - if self.is_hybrid_swa: - token_capacity, full_tokens, swa_tokens = self._resolve_hybrid_swa_tokens( - token_capacity - ) + available_bytes = self._profile_available_bytes(pre_model_load_memory) + page_size = self.server_args.page_size + + configurator = create_memory_pool_configurator(self) + configurator.calculate_pool_sizes(available_bytes, page_size) + + # Apply external constraints (user cap, page alignment, PP sync) + constrained = self._apply_token_constraints(configurator.max_total_num_tokens) + if constrained != configurator.max_total_num_tokens: + configurator.calculate_pool_sizes_from_max_tokens(constrained, page_size) + + full_tokens = getattr(configurator, "full_max_total_num_tokens", None) + swa_tokens = getattr(configurator, "swa_max_total_num_tokens", None) + + max_running_requests = self._resolve_max_num_reqs( + configurator.max_total_num_tokens + ) return MemoryPoolConfig( - max_total_num_tokens=token_capacity, - max_running_requests=self._resolve_max_num_reqs(token_capacity), + max_total_num_tokens=configurator.max_total_num_tokens, + max_running_requests=max_running_requests, full_max_total_num_tokens=full_tokens, swa_max_total_num_tokens=swa_tokens, mem_fraction_static=self.server_args.mem_fraction_static, diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index afb253b380af..4804bb24ecc7 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -1,3 +1,16 @@ +"""Memory pool configurators for profiling and sizing KV cache pools. + +Each model architecture has its own configurator that computes pool sizes +from available GPU memory using a unified coeff+bias model: + + available_bytes = max_tokens * coeff + bias + max_tokens = (available_bytes - bias) / coeff + +Two entry points, same core computation: +- calculate_pool_sizes(available_bytes, page_size): profiling path +- calculate_pool_sizes_from_max_tokens(max_tokens, page_size): constraint path +""" + from __future__ import annotations import logging @@ -16,157 +29,242 @@ logger = logging.getLogger(__name__) -def get_cell_size_per_token(mr: ModelRunner, num_layers: int) -> int: - # args to config cell size - model_config = mr.model_config - kv_cache_dtype = mr.kv_cache_dtype - use_mla_backend = mr.use_mla_backend +class MemoryPoolConfigurator: + """Base class for memory pool configurators. - kv_size = torch._utils._element_size(kv_cache_dtype) - if use_mla_backend: - cell_size = ( - (model_config.kv_lora_rank + model_config.qk_rope_head_dim) - * num_layers - * kv_size - ) - if is_float4_e2m1fn_x2(kv_cache_dtype): - # kv_scale_buffer - scale_block_size = 16 - cell_size = (cell_size // 2) + ( - ( - (model_config.kv_lora_rank + model_config.qk_rope_head_dim) - // scale_block_size - ) - * num_layers - * kv_size - ) + Subclasses compute pool sizes for their architecture via coeff+bias model. + Output fields are read by _resolve_memory_pool_config to build MemoryPoolConfig. + """ - # Add indexer KV cache overhead for NSA models (DeepSeek V3.2) - if is_deepseek_nsa(model_config.hf_config): - index_head_dim = get_nsa_index_head_dim(model_config.hf_config) - indexer_size_per_token = ( - index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4 - ) - element_size = torch._utils._element_size( - NSATokenToKVPool.index_k_with_scale_buffer_dtype - ) - cell_size += indexer_size_per_token * num_layers * element_size - else: - if model_config.is_hybrid_swa: - full_layers_num = len(model_config.full_attention_layer_ids) - swa_layers_num = len(model_config.swa_attention_layer_ids) - - full_per_token = model_config.get_num_kv_heads(get_attention_tp_size()) * ( - model_config.head_dim + model_config.v_head_dim + max_total_num_tokens: int = 0 + + def calculate_pool_sizes(self, available_bytes: int, page_size: int) -> None: + """Profiling path: compute pool sizes from available bytes.""" + raise NotImplementedError + + def calculate_pool_sizes_from_max_tokens( + self, max_total_num_tokens: int, page_size: int + ) -> None: + """Constraint path: recalculate pool sizes from a constrained max_tokens.""" + raise NotImplementedError + + +class DefaultPoolConfigurator(MemoryPoolConfigurator): + """Configurator for standard models: MHA, MLA, NSA, FP4. + + coeff = cell_size (bytes per token across all layers) + bias = 0 + """ + + def __init__(self, mr: ModelRunner): + self._use_mla_backend = mr.use_mla_backend + + # Determine effective number of layers for KV cache + if mambaish := mr.mambaish_config: + effective_layer_ids = [ + i + for i in mambaish.full_attention_layer_ids + if mr.start_layer <= i < mr.end_layer + ] + num_layers = len(effective_layer_ids) + else: + num_layers = mr.num_effective_layers + + self._cell_size = self._compute_cell_size(mr, num_layers) + + # DFLASH: scale cell_size to account for draft model KV cache + if mr.spec_algorithm.is_dflash() and not mr.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + scale_kv_cell_size_per_token_for_dflash, ) - swa_per_token = model_config.get_swa_num_kv_heads( - get_attention_tp_size() - ) * (model_config.swa_head_dim + model_config.swa_v_head_dim) + draft_num_layers = getattr(mr, "dflash_draft_num_layers", None) + if ( + draft_num_layers is not None + and int(draft_num_layers) > 0 + and int(num_layers) > 0 + ): + self._cell_size = scale_kv_cell_size_per_token_for_dflash( + target_cell_size_per_token=self._cell_size, + target_num_layers=int(num_layers), + draft_num_layers=int(draft_num_layers), + ) + + def _compute_cell_size(self, mr: ModelRunner, num_layers: int) -> int: + """Compute per-token KV cache cost in bytes. Subclasses can override.""" + # args to config cell size + model_config = mr.model_config + kv_cache_dtype = mr.kv_cache_dtype + kv_size = torch._utils._element_size(kv_cache_dtype) + tp_size = get_attention_tp_size() + + if self._use_mla_backend: cell_size = ( - full_per_token * full_layers_num + swa_per_token * swa_layers_num - ) * kv_size + (model_config.kv_lora_rank + model_config.qk_rope_head_dim) + * num_layers + * kv_size + ) + if is_float4_e2m1fn_x2(kv_cache_dtype): + # kv_scale_buffer + scale_block_size = 16 + cell_size = (cell_size // 2) + ( + ( + (model_config.kv_lora_rank + model_config.qk_rope_head_dim) + // scale_block_size + ) + * num_layers + * kv_size + ) + + # Add indexer KV cache overhead for NSA models (DeepSeek V3.2) + if is_deepseek_nsa(model_config.hf_config): + index_head_dim = get_nsa_index_head_dim(model_config.hf_config) + indexer_size_per_token = ( + index_head_dim + + index_head_dim // NSATokenToKVPool.quant_block_size * 4 + ) + element_size = torch._utils._element_size( + NSATokenToKVPool.index_k_with_scale_buffer_dtype + ) + cell_size += indexer_size_per_token * num_layers * element_size else: cell_size = ( - model_config.get_num_kv_heads(get_attention_tp_size()) + model_config.get_num_kv_heads(tp_size) * (model_config.head_dim + model_config.v_head_dim) * num_layers * kv_size ) - if is_float4_e2m1fn_x2(kv_cache_dtype): - # kv_scale_buffer - scale_block_size = 16 + if is_float4_e2m1fn_x2(kv_cache_dtype): + # kv_scale_buffer + scale_block_size = 16 + n = model_config.get_num_kv_heads(tp_size) + k = model_config.head_dim + cell_size = (cell_size // 2) + ( + (n * k * num_layers * 2 * kv_size) // scale_block_size + ) + + return cell_size - n = model_config.get_num_kv_heads(get_attention_tp_size()) - k = model_config.head_dim - cell_size = (cell_size // 2) + ( - (n * k * num_layers * 2 * kv_size) // scale_block_size - ) - return cell_size + def calculate_pool_sizes(self, available_bytes: int, page_size: int) -> None: + self.max_total_num_tokens = int(available_bytes) // self._cell_size + self.max_total_num_tokens = self.max_total_num_tokens // page_size * page_size + def calculate_pool_sizes_from_max_tokens( + self, max_total_num_tokens: int, page_size: int + ) -> None: + self.max_total_num_tokens = max_total_num_tokens // page_size * page_size -def resolve_hybrid_swa_tokens( - mr: ModelRunner, token_capacity: int -) -> tuple[int, int, int]: - """Split token_capacity into full/swa pools. - Returns (effective_capacity, full_max_total_num_tokens, swa_max_total_num_tokens). +class HybridSWAPoolConfigurator(MemoryPoolConfigurator): + """Configurator for hybrid sliding window attention models (Gemma2, Command-R, MiMo). + + Splits available memory between full attention and SWA pools. + Does NOT inherit DefaultPoolConfigurator — different coeff model. """ - model_config = mr.model_config - page_size = mr.server_args.page_size - swa_full_tokens_ratio = mr.server_args.swa_full_tokens_ratio - full_layers_num = len(model_config.full_attention_layer_ids) - swa_layers_num = len(model_config.swa_attention_layer_ids) - assert swa_layers_num > 0, "Hybrid SWA model must have at least one SWA layer" + full_max_total_num_tokens: int = 0 + swa_max_total_num_tokens: int = 0 + + def __init__(self, mr: ModelRunner): + model_config = mr.model_config + kv_cache_dtype = mr.kv_cache_dtype + kv_size = torch._utils._element_size(kv_cache_dtype) + tp_size = get_attention_tp_size() - def align_page_size(x: int) -> int: - return (x // page_size) * page_size + self._full_layers_num = len(model_config.full_attention_layer_ids) + self._swa_layers_num = len(model_config.swa_attention_layer_ids) + assert ( + self._swa_layers_num > 0 + ), "Hybrid SWA model must have at least one SWA layer" + + self._swa_full_tokens_ratio = mr.server_args.swa_full_tokens_ratio + + # Full layer per-token memory (bytes) + self._full_per_token = ( + model_config.get_num_kv_heads(tp_size) + * (model_config.head_dim + model_config.v_head_dim) + * kv_size + ) + + # SWA layer per-token memory (bytes) + self._swa_per_token = ( + model_config.get_swa_num_kv_heads(tp_size) + * (model_config.swa_head_dim + model_config.swa_v_head_dim) + * kv_size + ) + + # Profiling cell_size: weighted sum across all layers + # Used to convert between bytes and tokens for the constraint path. + self._cell_size = ( + self._full_per_token * self._full_layers_num + + self._swa_per_token * self._swa_layers_num + ) + + def _solve_pool_sizes(self, total_memory: int, page_size: int) -> None: + """Core computation: split total_memory into full/swa pool sizes.""" + + def align_page_size(x: int) -> int: + return (x // page_size) * page_size + + if self._full_layers_num == 0: + # All layers are SWA + swa_tokens = align_page_size( + total_memory // self._swa_per_token // self._swa_layers_num + ) + self.max_total_num_tokens = swa_tokens + self.full_max_total_num_tokens = 0 + self.swa_max_total_num_tokens = swa_tokens + logger.info( + f"Use sliding window memory pool (all SWA). " + f"swa_layer_tokens={swa_tokens}" + ) + return + + # Solve: + # full_tokens * F * n_full + swa_tokens * S * n_swa = total_memory + # swa_tokens = full_tokens * r + # => full_tokens = total_memory / (F * n_full + r * S * n_swa) + denominator = ( + self._full_per_token * self._full_layers_num + + self._swa_full_tokens_ratio * self._swa_per_token * self._swa_layers_num + ) + assert denominator > 0, ( + f"Invalid denominator={denominator}. " + f"full_per_token={self._full_per_token}, full_layers={self._full_layers_num}, " + f"swa_per_token={self._swa_per_token}, swa_layers={self._swa_layers_num}, " + f"ratio={self._swa_full_tokens_ratio}" + ) + + full_tokens = align_page_size(int(total_memory / denominator)) + swa_tokens = align_page_size(int(full_tokens * self._swa_full_tokens_ratio)) + + self.max_total_num_tokens = full_tokens + self.full_max_total_num_tokens = full_tokens + self.swa_max_total_num_tokens = swa_tokens - if full_layers_num == 0: - # all layers are SWA - swa_tokens = align_page_size(token_capacity) logger.info( - f"Use sliding window memory pool (all SWA). swa_layer_tokens={swa_tokens}" + f"Use sliding window memory pool. " + f"full_layer_tokens={full_tokens}, swa_layer_tokens={swa_tokens}" ) - return swa_tokens, 0, swa_tokens - - # Use unified memory-based allocation for all hybrid SWA models. - # - # Let: - # F = Full layer per-token memory - # S = SWA layer per-token memory (may differ from F) - # r = swa_full_tokens_ratio = swa_tokens / full_tokens - # - # The profile phase computed: - # cell_size = F * n_full + S * n_swa - # token_capacity = rest_memory / cell_size - # => total_memory = token_capacity * (F * n_full + S * n_swa) - # - # We need to solve: - # full_tokens * F * n_full + swa_tokens * S * n_swa = total_memory - # swa_tokens = full_tokens * r - # - # Solution: - # full_tokens = total_memory / (F * n_full + r * S * n_swa) - # = token_capacity * (F * n_full + S * n_swa) / (F * n_full + r * S * n_swa) - - kv_size = torch._utils._element_size(mr.kv_cache_dtype) - - # Full layer per-token memory - full_per_token = ( - model_config.get_num_kv_heads(get_attention_tp_size()) - * (model_config.head_dim + model_config.v_head_dim) - * kv_size - ) - - # SWA layer per-token memory - swa_per_token = ( - model_config.get_swa_num_kv_heads(get_attention_tp_size()) - * (model_config.swa_head_dim + model_config.swa_v_head_dim) - * kv_size - ) - - # Total memory available from profile - total_memory = token_capacity * ( - full_per_token * full_layers_num + swa_per_token * swa_layers_num - ) - - # Solve the equations - denominator = ( - full_per_token * full_layers_num - + swa_full_tokens_ratio * swa_per_token * swa_layers_num - ) - assert ( - denominator > 0 - ), f"Invalid denominator={denominator} for memory-based allocation. full_per_token={full_per_token}, full_layers_num={full_layers_num}, swa_per_token={swa_per_token}, swa_layers_num={swa_layers_num}, swa_full_tokens_ratio={swa_full_tokens_ratio}" - - full_tokens = align_page_size(int(total_memory / denominator)) - swa_tokens = align_page_size(int(full_tokens * swa_full_tokens_ratio)) - - logger.info( - f"Use sliding window memory pool. full_layer_tokens={full_tokens}, swa_layer_tokens={swa_tokens}" - ) - return full_tokens, full_tokens, swa_tokens + + def calculate_pool_sizes(self, available_bytes: int, page_size: int) -> None: + self._solve_pool_sizes(int(available_bytes), page_size) + + def calculate_pool_sizes_from_max_tokens( + self, max_total_num_tokens: int, page_size: int + ) -> None: + # Reconstruct total memory from constrained max_tokens + total_memory = max_total_num_tokens * self._cell_size + self._solve_pool_sizes(total_memory, page_size) + + +def create_memory_pool_configurator( + mr: ModelRunner, +) -> MemoryPoolConfigurator: + """Factory: select the right configurator for the model architecture.""" + if mr.is_hybrid_swa: + return HybridSWAPoolConfigurator(mr) + # Future: MambaPoolConfigurator + return DefaultPoolConfigurator(mr) From 03fa3fd3064cf5f0f7d9bb3755af7e007efda72f Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 16:24:03 -0700 Subject: [PATCH 02/12] unify all-SWA and hybrid SWA pool sizing --- .../srt/model_executor/pool_configurator.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index 4804bb24ecc7..6adc2ca9fc7f 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -208,24 +208,11 @@ def _solve_pool_sizes(self, total_memory: int, page_size: int) -> None: def align_page_size(x: int) -> int: return (x // page_size) * page_size - if self._full_layers_num == 0: - # All layers are SWA - swa_tokens = align_page_size( - total_memory // self._swa_per_token // self._swa_layers_num - ) - self.max_total_num_tokens = swa_tokens - self.full_max_total_num_tokens = 0 - self.swa_max_total_num_tokens = swa_tokens - logger.info( - f"Use sliding window memory pool (all SWA). " - f"swa_layer_tokens={swa_tokens}" - ) - return - # Solve: # full_tokens * F * n_full + swa_tokens * S * n_swa = total_memory # swa_tokens = full_tokens * r # => full_tokens = total_memory / (F * n_full + r * S * n_swa) + # When full_layers_num == 0, denominator = r * S * n_swa, formula still works. denominator = ( self._full_per_token * self._full_layers_num + self._swa_full_tokens_ratio * self._swa_per_token * self._swa_layers_num @@ -240,9 +227,11 @@ def align_page_size(x: int) -> int: full_tokens = align_page_size(int(total_memory / denominator)) swa_tokens = align_page_size(int(full_tokens * self._swa_full_tokens_ratio)) - self.max_total_num_tokens = full_tokens self.full_max_total_num_tokens = full_tokens self.swa_max_total_num_tokens = swa_tokens + self.max_total_num_tokens = ( + full_tokens if self._full_layers_num > 0 else swa_tokens + ) logger.info( f"Use sliding window memory pool. " From 726d6bf6b80674c4d5d9396223489d4562f1017f Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 16:34:44 -0700 Subject: [PATCH 03/12] configurator returns MemoryPoolConfig; move MemoryPoolConfig to pool_configurator --- .../model_runner_kv_cache_mixin.py | 50 +++--------- .../srt/model_executor/pool_configurator.py | 76 +++++++++++++------ 2 files changed, 63 insertions(+), 63 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index adb3f5d1591a..2eb24db03798 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -1,8 +1,7 @@ from __future__ import annotations import logging -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch @@ -30,6 +29,7 @@ ReqToTokenPool, ) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator +from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig # noqa: F401 from sglang.srt.utils.common import ( get_available_gpu_memory, is_float4_e2m1fn_x2, @@ -41,25 +41,6 @@ from sglang.srt.model_executor.model_runner import ModelRunner -@dataclass -class MemoryPoolConfig: - """Resolved memory pool config, shared between target and draft workers.""" - - max_total_num_tokens: int - max_running_requests: int - full_max_total_num_tokens: Optional[int] = None - swa_max_total_num_tokens: Optional[int] = None - - mem_fraction_static: Optional[float] = None - - def __post_init__(self): - if self.max_total_num_tokens <= 0: - msg = "Not enough memory. Please try to increase --mem-fraction-static." - if self.mem_fraction_static is not None: - msg += f" Current value: mem_fraction_static={self.mem_fraction_static}" - raise RuntimeError(msg) - - # the ratio of mamba cache pool size to max_running_requests MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3 MAMBA_CACHE_V2_ADDITIONAL_RATIO_OVERLAP = 2 @@ -706,27 +687,20 @@ def _resolve_memory_pool_config( page_size = self.server_args.page_size configurator = create_memory_pool_configurator(self) - configurator.calculate_pool_sizes(available_bytes, page_size) + config = configurator.calculate_pool_sizes(available_bytes, page_size) # Apply external constraints (user cap, page alignment, PP sync) - constrained = self._apply_token_constraints(configurator.max_total_num_tokens) - if constrained != configurator.max_total_num_tokens: - configurator.calculate_pool_sizes_from_max_tokens(constrained, page_size) - - full_tokens = getattr(configurator, "full_max_total_num_tokens", None) - swa_tokens = getattr(configurator, "swa_max_total_num_tokens", None) - - max_running_requests = self._resolve_max_num_reqs( - configurator.max_total_num_tokens - ) + constrained = self._apply_token_constraints(config.max_total_num_tokens) + if constrained != config.max_total_num_tokens: + config = configurator.calculate_pool_sizes_from_max_tokens( + constrained, page_size + ) - return MemoryPoolConfig( - max_total_num_tokens=configurator.max_total_num_tokens, - max_running_requests=max_running_requests, - full_max_total_num_tokens=full_tokens, - swa_max_total_num_tokens=swa_tokens, - mem_fraction_static=self.server_args.mem_fraction_static, + config.max_running_requests = self._resolve_max_num_reqs( + config.max_total_num_tokens ) + config.mem_fraction_static = self.server_args.mem_fraction_static + return config def init_memory_pool(self: ModelRunner, pre_model_load_memory: int): if not self.spec_algorithm.is_none() and self.is_draft_worker: diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index 6adc2ca9fc7f..a62b929d0fa1 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -14,7 +14,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional import torch @@ -23,6 +24,26 @@ from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool from sglang.srt.utils.common import is_float4_e2m1fn_x2 + +@dataclass +class MemoryPoolConfig: + """Resolved memory pool config, shared between target and draft workers.""" + + max_total_num_tokens: int + max_running_requests: Optional[int] = None + full_max_total_num_tokens: Optional[int] = None + swa_max_total_num_tokens: Optional[int] = None + + mem_fraction_static: Optional[float] = None + + def __post_init__(self): + if self.max_total_num_tokens <= 0: + msg = "Not enough memory. Please try to increase --mem-fraction-static." + if self.mem_fraction_static is not None: + msg += f" Current value: mem_fraction_static={self.mem_fraction_static}" + raise RuntimeError(msg) + + if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -33,18 +54,19 @@ class MemoryPoolConfigurator: """Base class for memory pool configurators. Subclasses compute pool sizes for their architecture via coeff+bias model. - Output fields are read by _resolve_memory_pool_config to build MemoryPoolConfig. + Both entry points return MemoryPoolConfig (with max_running_requests=0, + to be filled by the consumer). """ - max_total_num_tokens: int = 0 - - def calculate_pool_sizes(self, available_bytes: int, page_size: int) -> None: + def calculate_pool_sizes( + self, available_bytes: int, page_size: int + ) -> MemoryPoolConfig: """Profiling path: compute pool sizes from available bytes.""" raise NotImplementedError def calculate_pool_sizes_from_max_tokens( self, max_total_num_tokens: int, page_size: int - ) -> None: + ) -> MemoryPoolConfig: """Constraint path: recalculate pool sizes from a constrained max_tokens.""" raise NotImplementedError @@ -147,14 +169,18 @@ def _compute_cell_size(self, mr: ModelRunner, num_layers: int) -> int: return cell_size - def calculate_pool_sizes(self, available_bytes: int, page_size: int) -> None: - self.max_total_num_tokens = int(available_bytes) // self._cell_size - self.max_total_num_tokens = self.max_total_num_tokens // page_size * page_size + def calculate_pool_sizes( + self, available_bytes: int, page_size: int + ) -> MemoryPoolConfig: + max_total_num_tokens = int(available_bytes) // self._cell_size + max_total_num_tokens = max_total_num_tokens // page_size * page_size + return MemoryPoolConfig(max_total_num_tokens=max_total_num_tokens) def calculate_pool_sizes_from_max_tokens( self, max_total_num_tokens: int, page_size: int - ) -> None: - self.max_total_num_tokens = max_total_num_tokens // page_size * page_size + ) -> MemoryPoolConfig: + max_total_num_tokens = max_total_num_tokens // page_size * page_size + return MemoryPoolConfig(max_total_num_tokens=max_total_num_tokens) class HybridSWAPoolConfigurator(MemoryPoolConfigurator): @@ -164,9 +190,6 @@ class HybridSWAPoolConfigurator(MemoryPoolConfigurator): Does NOT inherit DefaultPoolConfigurator — different coeff model. """ - full_max_total_num_tokens: int = 0 - swa_max_total_num_tokens: int = 0 - def __init__(self, mr: ModelRunner): model_config = mr.model_config kv_cache_dtype = mr.kv_cache_dtype @@ -202,7 +225,7 @@ def __init__(self, mr: ModelRunner): + self._swa_per_token * self._swa_layers_num ) - def _solve_pool_sizes(self, total_memory: int, page_size: int) -> None: + def _solve_pool_sizes(self, total_memory: int, page_size: int) -> MemoryPoolConfig: """Core computation: split total_memory into full/swa pool sizes.""" def align_page_size(x: int) -> int: @@ -226,27 +249,30 @@ def align_page_size(x: int) -> int: full_tokens = align_page_size(int(total_memory / denominator)) swa_tokens = align_page_size(int(full_tokens * self._swa_full_tokens_ratio)) - - self.full_max_total_num_tokens = full_tokens - self.swa_max_total_num_tokens = swa_tokens - self.max_total_num_tokens = ( - full_tokens if self._full_layers_num > 0 else swa_tokens - ) + max_total_num_tokens = full_tokens if self._full_layers_num > 0 else swa_tokens logger.info( f"Use sliding window memory pool. " f"full_layer_tokens={full_tokens}, swa_layer_tokens={swa_tokens}" ) - def calculate_pool_sizes(self, available_bytes: int, page_size: int) -> None: - self._solve_pool_sizes(int(available_bytes), page_size) + return MemoryPoolConfig( + max_total_num_tokens=max_total_num_tokens, + full_max_total_num_tokens=full_tokens, + swa_max_total_num_tokens=swa_tokens, + ) + + def calculate_pool_sizes( + self, available_bytes: int, page_size: int + ) -> MemoryPoolConfig: + return self._solve_pool_sizes(int(available_bytes), page_size) def calculate_pool_sizes_from_max_tokens( self, max_total_num_tokens: int, page_size: int - ) -> None: + ) -> MemoryPoolConfig: # Reconstruct total memory from constrained max_tokens total_memory = max_total_num_tokens * self._cell_size - self._solve_pool_sizes(total_memory, page_size) + return self._solve_pool_sizes(total_memory, page_size) def create_memory_pool_configurator( From 893146a96b95912e73b49b3eab268b30df21bf3a Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 16:38:00 -0700 Subject: [PATCH 04/12] import MemoryPoolConfig from pool_configurator directly --- python/sglang/srt/managers/tp_worker.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 7f63610da8ee..347805f63e81 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -54,7 +54,7 @@ if TYPE_CHECKING: from sglang.srt.managers.cache_controller import LayerDoneCounter from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.model_executor.model_runner_kv_cache_mixin import MemoryPoolConfig + from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a00d0f989600..c9f15004da6c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -136,12 +136,12 @@ ) from sglang.srt.model_executor.hook_manager import register_forward_hooks from sglang.srt.model_executor.model_runner_kv_cache_mixin import ( - MemoryPoolConfig, ModelRunnerKVCacheMixin, ) from sglang.srt.model_executor.piecewise_cuda_graph_runner import ( PiecewiseCudaGraphRunner, ) +from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( RemoteInstanceWeightLoaderBackend, diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 2eb24db03798..e7c548aa7d10 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -29,7 +29,6 @@ ReqToTokenPool, ) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator -from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig # noqa: F401 from sglang.srt.utils.common import ( get_available_gpu_memory, is_float4_e2m1fn_x2, @@ -39,6 +38,7 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig # the ratio of mamba cache pool size to max_running_requests From 478132df46d1f38cab04e6f2bf8d2c90c4062ad2 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 17:14:32 -0700 Subject: [PATCH 05/12] cleanup: _compute_cell_size reads from mr only; _profile_available_bytes returns int --- .../srt/model_executor/model_runner_kv_cache_mixin.py | 6 ++---- python/sglang/srt/model_executor/pool_configurator.py | 8 +++----- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index e7c548aa7d10..50f2b9af8946 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -54,9 +54,7 @@ class ModelRunnerKVCacheMixin: - def _profile_available_bytes( - self: ModelRunner, pre_model_load_memory: int - ) -> float: + def _profile_available_bytes(self: ModelRunner, pre_model_load_memory: int) -> int: post_model_load_memory = get_available_gpu_memory( self.device, self.gpu_id, @@ -70,7 +68,7 @@ def _profile_available_bytes( if self.mambaish_config is not None: rest_memory = self.handle_max_mamba_cache(rest_memory) - return rest_memory * (1 << 30) # return in bytes + return int(rest_memory * (1 << 30)) # return in bytes def handle_max_mamba_cache(self: ModelRunner, total_rest_memory): config = self.mambaish_config diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index a62b929d0fa1..ff3b2e42efbc 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -79,8 +79,6 @@ class DefaultPoolConfigurator(MemoryPoolConfigurator): """ def __init__(self, mr: ModelRunner): - self._use_mla_backend = mr.use_mla_backend - # Determine effective number of layers for KV cache if mambaish := mr.mambaish_config: effective_layer_ids = [ @@ -121,7 +119,7 @@ def _compute_cell_size(self, mr: ModelRunner, num_layers: int) -> int: kv_size = torch._utils._element_size(kv_cache_dtype) tp_size = get_attention_tp_size() - if self._use_mla_backend: + if mr.use_mla_backend: cell_size = ( (model_config.kv_lora_rank + model_config.qk_rope_head_dim) * num_layers @@ -172,7 +170,7 @@ def _compute_cell_size(self, mr: ModelRunner, num_layers: int) -> int: def calculate_pool_sizes( self, available_bytes: int, page_size: int ) -> MemoryPoolConfig: - max_total_num_tokens = int(available_bytes) // self._cell_size + max_total_num_tokens = available_bytes // self._cell_size max_total_num_tokens = max_total_num_tokens // page_size * page_size return MemoryPoolConfig(max_total_num_tokens=max_total_num_tokens) @@ -265,7 +263,7 @@ def align_page_size(x: int) -> int: def calculate_pool_sizes( self, available_bytes: int, page_size: int ) -> MemoryPoolConfig: - return self._solve_pool_sizes(int(available_bytes), page_size) + return self._solve_pool_sizes(available_bytes, page_size) def calculate_pool_sizes_from_max_tokens( self, max_total_num_tokens: int, page_size: int From 393d338ed0046bc93918cd3425d74afd6f67ffed Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 17:23:34 -0700 Subject: [PATCH 06/12] fix docstring; remove redundant page alignment from _apply_token_constraints --- .../srt/model_executor/model_runner_kv_cache_mixin.py | 10 +++++----- python/sglang/srt/model_executor/pool_configurator.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 50f2b9af8946..9ecc0c414398 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -614,7 +614,11 @@ def _init_pools(self: ModelRunner): ) def _apply_token_constraints(self: ModelRunner, token_capacity: int) -> int: - """Apply external constraints to token capacity: user cap, page alignment, PP sync.""" + """Apply external constraints to token capacity: user cap, PP sync. + + Page alignment is handled by the configurator, not here. + If constraints change the value, the configurator re-runs and re-aligns. + """ user_limit = self.server_args.max_total_tokens # Apply user-specified upper bound @@ -626,10 +630,6 @@ def _apply_token_constraints(self: ModelRunner, token_capacity: int) -> int: ) token_capacity = min(token_capacity, user_limit) - # Align to page boundary - page_size = self.server_args.page_size - token_capacity = token_capacity // page_size * page_size - # Sync across PP ranks (each may have different layer counts) if self.pp_size > 1: tensor = torch.tensor(token_capacity, dtype=torch.int64) diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index ff3b2e42efbc..7bb7d0d317ea 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -54,7 +54,7 @@ class MemoryPoolConfigurator: """Base class for memory pool configurators. Subclasses compute pool sizes for their architecture via coeff+bias model. - Both entry points return MemoryPoolConfig (with max_running_requests=0, + Both entry points return MemoryPoolConfig (with max_running_requests=None, to be filled by the consumer). """ From 2027e460823b1e2e79729f8e9a841ea010995166 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 18:30:43 -0700 Subject: [PATCH 07/12] fix SWA cell_size to include ratio; simplify _solve_pool_sizes --- .../srt/model_executor/pool_configurator.py | 45 +++++++------------ 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index 7bb7d0d317ea..922765248f71 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -216,38 +216,26 @@ def __init__(self, mr: ModelRunner): * kv_size ) - # Profiling cell_size: weighted sum across all layers - # Used to convert between bytes and tokens for the constraint path. + # Bytes per full_token (accounts for SWA ratio). + # full_tokens * _cell_size = total memory consumed by both pools. self._cell_size = ( self._full_per_token * self._full_layers_num - + self._swa_per_token * self._swa_layers_num + + self._swa_full_tokens_ratio * self._swa_per_token * self._swa_layers_num ) - def _solve_pool_sizes(self, total_memory: int, page_size: int) -> MemoryPoolConfig: - """Core computation: split total_memory into full/swa pool sizes.""" + def _solve_pool_sizes( + self, max_total_num_tokens: int, page_size: int + ) -> MemoryPoolConfig: + """Core computation: split max_total_num_tokens into full/swa pool sizes.""" def align_page_size(x: int) -> int: return (x // page_size) * page_size - # Solve: - # full_tokens * F * n_full + swa_tokens * S * n_swa = total_memory - # swa_tokens = full_tokens * r - # => full_tokens = total_memory / (F * n_full + r * S * n_swa) - # When full_layers_num == 0, denominator = r * S * n_swa, formula still works. - denominator = ( - self._full_per_token * self._full_layers_num - + self._swa_full_tokens_ratio * self._swa_per_token * self._swa_layers_num - ) - assert denominator > 0, ( - f"Invalid denominator={denominator}. " - f"full_per_token={self._full_per_token}, full_layers={self._full_layers_num}, " - f"swa_per_token={self._swa_per_token}, swa_layers={self._swa_layers_num}, " - f"ratio={self._swa_full_tokens_ratio}" - ) - - full_tokens = align_page_size(int(total_memory / denominator)) + # full_tokens = max_total_num_tokens (page aligned) + # swa_tokens = full_tokens * ratio (page aligned) + # When full_layers_num == 0, max_total_num_tokens is swa_tokens. + full_tokens = align_page_size(max_total_num_tokens) swa_tokens = align_page_size(int(full_tokens * self._swa_full_tokens_ratio)) - max_total_num_tokens = full_tokens if self._full_layers_num > 0 else swa_tokens logger.info( f"Use sliding window memory pool. " @@ -255,7 +243,9 @@ def align_page_size(x: int) -> int: ) return MemoryPoolConfig( - max_total_num_tokens=max_total_num_tokens, + max_total_num_tokens=( + full_tokens if self._full_layers_num > 0 else swa_tokens + ), full_max_total_num_tokens=full_tokens, swa_max_total_num_tokens=swa_tokens, ) @@ -263,14 +253,13 @@ def align_page_size(x: int) -> int: def calculate_pool_sizes( self, available_bytes: int, page_size: int ) -> MemoryPoolConfig: - return self._solve_pool_sizes(available_bytes, page_size) + max_total_num_tokens = available_bytes // self._cell_size + return self._solve_pool_sizes(max_total_num_tokens, page_size) def calculate_pool_sizes_from_max_tokens( self, max_total_num_tokens: int, page_size: int ) -> MemoryPoolConfig: - # Reconstruct total memory from constrained max_tokens - total_memory = max_total_num_tokens * self._cell_size - return self._solve_pool_sizes(total_memory, page_size) + return self._solve_pool_sizes(max_total_num_tokens, page_size) def create_memory_pool_configurator( From 545ae79f82e8bd80ce265f0bdab8bb2f9d0d4376 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 18:50:01 -0700 Subject: [PATCH 08/12] handle all-SWA edge case; int() for float cell_size division --- .../srt/model_executor/pool_configurator.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index 922765248f71..f687e96cf8db 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -231,9 +231,21 @@ def _solve_pool_sizes( def align_page_size(x: int) -> int: return (x // page_size) * page_size + if self._full_layers_num == 0: + # All layers are SWA — no full pool needed + swa_tokens = align_page_size(max_total_num_tokens) + logger.info( + f"Use sliding window memory pool (all SWA). " + f"swa_layer_tokens={swa_tokens}" + ) + return MemoryPoolConfig( + max_total_num_tokens=swa_tokens, + full_max_total_num_tokens=0, + swa_max_total_num_tokens=swa_tokens, + ) + # full_tokens = max_total_num_tokens (page aligned) # swa_tokens = full_tokens * ratio (page aligned) - # When full_layers_num == 0, max_total_num_tokens is swa_tokens. full_tokens = align_page_size(max_total_num_tokens) swa_tokens = align_page_size(int(full_tokens * self._swa_full_tokens_ratio)) @@ -243,9 +255,7 @@ def align_page_size(x: int) -> int: ) return MemoryPoolConfig( - max_total_num_tokens=( - full_tokens if self._full_layers_num > 0 else swa_tokens - ), + max_total_num_tokens=full_tokens, full_max_total_num_tokens=full_tokens, swa_max_total_num_tokens=swa_tokens, ) @@ -253,7 +263,7 @@ def align_page_size(x: int) -> int: def calculate_pool_sizes( self, available_bytes: int, page_size: int ) -> MemoryPoolConfig: - max_total_num_tokens = available_bytes // self._cell_size + max_total_num_tokens = int(available_bytes // self._cell_size) return self._solve_pool_sizes(max_total_num_tokens, page_size) def calculate_pool_sizes_from_max_tokens( From 6fde4f6f8697eec653634eadc0576eeb51980666 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 19:06:03 -0700 Subject: [PATCH 09/12] fix all-SWA _cell_size: use S*ns without ratio --- .../srt/model_executor/pool_configurator.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index f687e96cf8db..b6c250c169bd 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -216,12 +216,18 @@ def __init__(self, mr: ModelRunner): * kv_size ) - # Bytes per full_token (accounts for SWA ratio). - # full_tokens * _cell_size = total memory consumed by both pools. - self._cell_size = ( - self._full_per_token * self._full_layers_num - + self._swa_full_tokens_ratio * self._swa_per_token * self._swa_layers_num - ) + # Bytes per max_total_num_token. + # For hybrid (full_layers > 0): full_tokens * _cell_size = total memory for both pools. + # For all-SWA (full_layers == 0): swa_tokens * _cell_size = total SWA memory. + if self._full_layers_num == 0: + self._cell_size = self._swa_per_token * self._swa_layers_num + else: + self._cell_size = ( + self._full_per_token * self._full_layers_num + + self._swa_full_tokens_ratio + * self._swa_per_token + * self._swa_layers_num + ) def _solve_pool_sizes( self, max_total_num_tokens: int, page_size: int From a55116f58258db2d15e53a0bac7260c61a61613e Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 22:16:50 -0700 Subject: [PATCH 10/12] add CPU unit tests for pool configurator --- .../model_executor/test_pool_configurator.py | 316 ++++++++++++++++++ 1 file changed, 316 insertions(+) create mode 100644 test/registered/unit/model_executor/test_pool_configurator.py diff --git a/test/registered/unit/model_executor/test_pool_configurator.py b/test/registered/unit/model_executor/test_pool_configurator.py new file mode 100644 index 000000000000..41551c9cf691 --- /dev/null +++ b/test/registered/unit/model_executor/test_pool_configurator.py @@ -0,0 +1,316 @@ +"""Unit tests for pool_configurator.py -- CPU only, no GPU required. + +Tests the end-to-end computation: available_bytes -> MemoryPoolConfig, +verifying tokens are correct, constraints are respected, and memory +invariants hold (tokens * per_token_cost <= available_bytes). +""" + +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from sglang.test.ci.ci_register import register_cpu_ci + +register_cpu_ci(est_time=5, suite="stage-a-test-cpu") + + +def _make_model_runner( + *, + num_kv_heads=4, + head_dim=64, + v_head_dim=64, + num_layers=32, + use_mla_backend=False, + is_hybrid_swa=False, + full_attention_layer_ids=None, + swa_attention_layer_ids=None, + swa_num_kv_heads=None, + swa_head_dim=None, + swa_v_head_dim=None, + swa_full_tokens_ratio=0.5, + page_size=1, + mambaish_config=None, +): + """Create a mock ModelRunner with the fields configurators need.""" + mr = MagicMock() + + mr.use_mla_backend = use_mla_backend + mr.is_draft_worker = False + mr.num_effective_layers = num_layers + mr.start_layer = 0 + mr.end_layer = num_layers + mr.mambaish_config = mambaish_config + mr.is_hybrid_swa = is_hybrid_swa + + mc = SimpleNamespace() + mc.head_dim = head_dim + mc.v_head_dim = v_head_dim + mc.is_hybrid_swa = is_hybrid_swa + mc.full_attention_layer_ids = full_attention_layer_ids or list(range(num_layers)) + mc.swa_attention_layer_ids = swa_attention_layer_ids or [] + mc.swa_head_dim = swa_head_dim or head_dim + mc.swa_v_head_dim = swa_v_head_dim or v_head_dim + mc.get_num_kv_heads = lambda tp_size: num_kv_heads + mc.get_swa_num_kv_heads = lambda tp_size: swa_num_kv_heads or num_kv_heads + mc.hf_config = SimpleNamespace(architectures=["LlamaForCausalLM"]) + mr.model_config = mc + + mr.kv_cache_dtype = "fake_bf16" + + sa = SimpleNamespace() + sa.swa_full_tokens_ratio = swa_full_tokens_ratio + sa.page_size = page_size + mr.server_args = sa + + spec = MagicMock() + spec.is_dflash.return_value = False + spec.is_none.return_value = True + mr.spec_algorithm = spec + + return mr + + +KV_SIZE = 2 # bf16 + + +def _full_per_token(mr): + mc = mr.model_config + return mc.get_num_kv_heads(1) * (mc.head_dim + mc.v_head_dim) * KV_SIZE + + +def _swa_per_token(mr): + mc = mr.model_config + return mc.get_swa_num_kv_heads(1) * (mc.swa_head_dim + mc.swa_v_head_dim) * KV_SIZE + + +def _actual_memory_used(mr, config): + """Compute actual memory consumed by the pool sizes in config.""" + mc = mr.model_config + full_pt = _full_per_token(mr) + swa_pt = _swa_per_token(mr) + nf = len(mc.full_attention_layer_ids) + ns = len(mc.swa_attention_layer_ids) + + if mr.is_hybrid_swa: + full = config.full_max_total_num_tokens or 0 + swa = config.swa_max_total_num_tokens or 0 + return full * full_pt * nf + swa * swa_pt * ns + else: + return config.max_total_num_tokens * full_pt * (nf + ns) + + +class TestDefaultConfigurator(unittest.TestCase): + """Default (MHA): available_bytes -> tokens, memory invariant holds.""" + + def _run(self, available_bytes, page_size=1, **kwargs): + mr = _make_model_runner(page_size=page_size, **kwargs) + with patch("torch._utils._element_size", return_value=KV_SIZE): + from sglang.srt.model_executor.pool_configurator import ( + create_memory_pool_configurator, + ) + + cfg = create_memory_pool_configurator(mr) + config = cfg.calculate_pool_sizes(available_bytes, page_size) + return mr, cfg, config + + def test_memory_utilization(self): + """Memory used should be <= available and within 1% of available.""" + available = 10_000_000 + mr, cfg, config = self._run(available) + used = _actual_memory_used(mr, config) + self.assertLessEqual(used, available) + self.assertGreater(used, available * 0.99) + + def test_page_alignment(self): + available = 10_000_000 + _, _, config = self._run(available, page_size=128) + self.assertEqual(config.max_total_num_tokens % 128, 0) + + def test_constraint_respected(self): + """calculate_pool_sizes_from_max_tokens respects the limit.""" + mr, cfg, config = self._run(10_000_000) + with patch("torch._utils._element_size", return_value=KV_SIZE): + constrained = cfg.calculate_pool_sizes_from_max_tokens(100, page_size=1) + self.assertEqual(constrained.max_total_num_tokens, 100) + + def test_constraint_page_aligned(self): + mr, cfg, _ = self._run(10_000_000, page_size=128) + with patch("torch._utils._element_size", return_value=KV_SIZE): + constrained = cfg.calculate_pool_sizes_from_max_tokens(1000, page_size=128) + self.assertEqual(constrained.max_total_num_tokens, 896) # 1000 // 128 * 128 + + def test_no_swa_fields(self): + _, _, config = self._run(10_000_000) + self.assertIsNone(config.full_max_total_num_tokens) + self.assertIsNone(config.swa_max_total_num_tokens) + + +class TestHybridSWAConfigurator(unittest.TestCase): + """Hybrid SWA: full/swa split, ratio, memory invariant.""" + + def _make_swa_runner(self, full_layers=16, swa_layers=16, ratio=0.5, page_size=1): + return _make_model_runner( + is_hybrid_swa=True, + full_attention_layer_ids=list(range(full_layers)), + swa_attention_layer_ids=list(range(full_layers, full_layers + swa_layers)), + swa_num_kv_heads=4, + page_size=page_size, + swa_full_tokens_ratio=ratio, + ) + + def _run(self, available_bytes, **kwargs): + mr = self._make_swa_runner(**kwargs) + with patch("torch._utils._element_size", return_value=KV_SIZE): + from sglang.srt.model_executor.pool_configurator import ( + create_memory_pool_configurator, + ) + + cfg = create_memory_pool_configurator(mr) + config = cfg.calculate_pool_sizes(available_bytes, mr.server_args.page_size) + return mr, cfg, config + + def test_memory_utilization(self): + """Memory used should be <= available and within 1% of available.""" + available = 10_000_000 + mr, _, config = self._run(available) + used = _actual_memory_used(mr, config) + self.assertLessEqual(used, available) + self.assertGreater(used, available * 0.99) + + def test_ratio_respected(self): + """swa_tokens ~= full_tokens * ratio (within page alignment)""" + available = 10_000_000 + for ratio in [0.25, 0.5, 0.75, 1.0]: + mr, _, config = self._run(available, ratio=ratio, page_size=1) + full = config.full_max_total_num_tokens + swa = config.swa_max_total_num_tokens + self.assertEqual(swa, int(full * ratio), f"ratio={ratio}") + + def test_ratio_with_page_alignment(self): + """With page alignment, swa_tokens = align(full_tokens * ratio)""" + available = 10_000_000 + mr, _, config = self._run(available, ratio=0.5, page_size=128) + full = config.full_max_total_num_tokens + swa = config.swa_max_total_num_tokens + self.assertEqual(full % 128, 0) + self.assertEqual(swa % 128, 0) + self.assertEqual(swa, (int(full * 0.5) // 128) * 128) + + def test_max_total_equals_full(self): + """For hybrid, max_total_num_tokens = full_max_total_num_tokens""" + _, _, config = self._run(10_000_000) + self.assertEqual(config.max_total_num_tokens, config.full_max_total_num_tokens) + + def test_constraint_respected(self): + """full_tokens = constrained value after re-run""" + mr, cfg, _ = self._run(10_000_000, page_size=1) + with patch("torch._utils._element_size", return_value=KV_SIZE): + config = cfg.calculate_pool_sizes_from_max_tokens(200, page_size=1) + self.assertEqual(config.full_max_total_num_tokens, 200) + self.assertEqual(config.swa_max_total_num_tokens, 100) + + def test_constraint_memory_within_budget(self): + """After constraint, memory <= original budget (but less than profiled due to constraint).""" + available = 10_000_000 + mr, cfg, original = self._run(available, page_size=1) + user_limit = original.full_max_total_num_tokens // 2 + with patch("torch._utils._element_size", return_value=KV_SIZE): + config = cfg.calculate_pool_sizes_from_max_tokens( + user_limit, mr.server_args.page_size + ) + used = _actual_memory_used(mr, config) + self.assertLessEqual(used, available) + # constrained should use roughly half the memory + original_used = _actual_memory_used(mr, original) + self.assertAlmostEqual(used / original_used, 0.5, delta=0.01) + + def test_different_layer_counts(self): + """Asymmetric full/swa layer counts""" + available = 10_000_000 + mr, _, config = self._run(available, full_layers=24, swa_layers=8, ratio=0.5) + used = _actual_memory_used(mr, config) + self.assertLessEqual(used, available) + self.assertEqual( + config.swa_max_total_num_tokens, + int(config.full_max_total_num_tokens * 0.5), + ) + + +class TestAllSWAConfigurator(unittest.TestCase): + """All-SWA (full_layers=0): special case.""" + + def _run(self, available_bytes, ratio=0.5, page_size=1): + mr = _make_model_runner( + is_hybrid_swa=True, + full_attention_layer_ids=[], + swa_attention_layer_ids=list(range(32)), + swa_num_kv_heads=4, + swa_full_tokens_ratio=ratio, + page_size=page_size, + ) + with patch("torch._utils._element_size", return_value=KV_SIZE): + from sglang.srt.model_executor.pool_configurator import ( + create_memory_pool_configurator, + ) + + cfg = create_memory_pool_configurator(mr) + config = cfg.calculate_pool_sizes(available_bytes, page_size) + return mr, cfg, config + + def test_full_max_is_zero(self): + _, _, config = self._run(10_000_000) + self.assertEqual(config.full_max_total_num_tokens, 0) + + def test_max_total_equals_swa(self): + _, _, config = self._run(10_000_000) + self.assertEqual(config.max_total_num_tokens, config.swa_max_total_num_tokens) + + def test_memory_utilization(self): + """Memory used should be <= available and within 1% of available.""" + available = 10_000_000 + mr, _, config = self._run(available) + swa_pt = _swa_per_token(mr) + ns = len(mr.model_config.swa_attention_layer_ids) + used = config.swa_max_total_num_tokens * swa_pt * ns + self.assertLessEqual(used, available) + self.assertGreater(used, available * 0.99) + + def test_constraint_respected(self): + mr, cfg, _ = self._run(10_000_000, page_size=1) + with patch("torch._utils._element_size", return_value=KV_SIZE): + config = cfg.calculate_pool_sizes_from_max_tokens(500, page_size=1) + self.assertEqual(config.max_total_num_tokens, 500) + self.assertEqual(config.swa_max_total_num_tokens, 500) + + +class TestFactory(unittest.TestCase): + def test_default_for_non_swa(self): + mr = _make_model_runner(is_hybrid_swa=False) + with patch("torch._utils._element_size", return_value=KV_SIZE): + from sglang.srt.model_executor.pool_configurator import ( + DefaultPoolConfigurator, + create_memory_pool_configurator, + ) + + cfg = create_memory_pool_configurator(mr) + self.assertIsInstance(cfg, DefaultPoolConfigurator) + + def test_swa_for_hybrid(self): + mr = _make_model_runner( + is_hybrid_swa=True, + full_attention_layer_ids=list(range(16)), + swa_attention_layer_ids=list(range(16, 32)), + swa_num_kv_heads=4, + ) + with patch("torch._utils._element_size", return_value=KV_SIZE): + from sglang.srt.model_executor.pool_configurator import ( + HybridSWAPoolConfigurator, + create_memory_pool_configurator, + ) + + cfg = create_memory_pool_configurator(mr) + self.assertIsInstance(cfg, HybridSWAPoolConfigurator) + + +if __name__ == "__main__": + unittest.main() From fa281fba89b1e275be607d3020c274344a01f107 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 22:26:44 -0700 Subject: [PATCH 11/12] fix: mock get_attention_tp_size for CPU-only test --- .../model_executor/test_pool_configurator.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/test/registered/unit/model_executor/test_pool_configurator.py b/test/registered/unit/model_executor/test_pool_configurator.py index 41551c9cf691..1800ec2243c5 100644 --- a/test/registered/unit/model_executor/test_pool_configurator.py +++ b/test/registered/unit/model_executor/test_pool_configurator.py @@ -5,6 +5,7 @@ invariants hold (tokens * per_token_cost <= available_bytes). """ +import contextlib import unittest from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -14,6 +15,19 @@ register_cpu_ci(est_time=5, suite="stage-a-test-cpu") +@contextlib.contextmanager +def mock_cpu_env(kv_size=2, tp_size=1): + """Mock GPU-dependent functions for CPU-only testing.""" + with ( + patch("torch._utils._element_size", return_value=kv_size), + patch( + "sglang.srt.model_executor.pool_configurator.get_attention_tp_size", + return_value=tp_size, + ), + ): + yield + + def _make_model_runner( *, num_kv_heads=4, @@ -104,7 +118,7 @@ class TestDefaultConfigurator(unittest.TestCase): def _run(self, available_bytes, page_size=1, **kwargs): mr = _make_model_runner(page_size=page_size, **kwargs) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): from sglang.srt.model_executor.pool_configurator import ( create_memory_pool_configurator, ) @@ -129,13 +143,13 @@ def test_page_alignment(self): def test_constraint_respected(self): """calculate_pool_sizes_from_max_tokens respects the limit.""" mr, cfg, config = self._run(10_000_000) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): constrained = cfg.calculate_pool_sizes_from_max_tokens(100, page_size=1) self.assertEqual(constrained.max_total_num_tokens, 100) def test_constraint_page_aligned(self): mr, cfg, _ = self._run(10_000_000, page_size=128) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): constrained = cfg.calculate_pool_sizes_from_max_tokens(1000, page_size=128) self.assertEqual(constrained.max_total_num_tokens, 896) # 1000 // 128 * 128 @@ -160,7 +174,7 @@ def _make_swa_runner(self, full_layers=16, swa_layers=16, ratio=0.5, page_size=1 def _run(self, available_bytes, **kwargs): mr = self._make_swa_runner(**kwargs) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): from sglang.srt.model_executor.pool_configurator import ( create_memory_pool_configurator, ) @@ -204,7 +218,7 @@ def test_max_total_equals_full(self): def test_constraint_respected(self): """full_tokens = constrained value after re-run""" mr, cfg, _ = self._run(10_000_000, page_size=1) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): config = cfg.calculate_pool_sizes_from_max_tokens(200, page_size=1) self.assertEqual(config.full_max_total_num_tokens, 200) self.assertEqual(config.swa_max_total_num_tokens, 100) @@ -214,7 +228,7 @@ def test_constraint_memory_within_budget(self): available = 10_000_000 mr, cfg, original = self._run(available, page_size=1) user_limit = original.full_max_total_num_tokens // 2 - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): config = cfg.calculate_pool_sizes_from_max_tokens( user_limit, mr.server_args.page_size ) @@ -248,7 +262,7 @@ def _run(self, available_bytes, ratio=0.5, page_size=1): swa_full_tokens_ratio=ratio, page_size=page_size, ) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): from sglang.srt.model_executor.pool_configurator import ( create_memory_pool_configurator, ) @@ -277,7 +291,7 @@ def test_memory_utilization(self): def test_constraint_respected(self): mr, cfg, _ = self._run(10_000_000, page_size=1) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): config = cfg.calculate_pool_sizes_from_max_tokens(500, page_size=1) self.assertEqual(config.max_total_num_tokens, 500) self.assertEqual(config.swa_max_total_num_tokens, 500) @@ -286,7 +300,7 @@ def test_constraint_respected(self): class TestFactory(unittest.TestCase): def test_default_for_non_swa(self): mr = _make_model_runner(is_hybrid_swa=False) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): from sglang.srt.model_executor.pool_configurator import ( DefaultPoolConfigurator, create_memory_pool_configurator, @@ -302,7 +316,7 @@ def test_swa_for_hybrid(self): swa_attention_layer_ids=list(range(16, 32)), swa_num_kv_heads=4, ) - with patch("torch._utils._element_size", return_value=KV_SIZE): + with mock_cpu_env(): from sglang.srt.model_executor.pool_configurator import ( HybridSWAPoolConfigurator, create_memory_pool_configurator, From d715421666e1315004d0646bddbf5248e3a6eadd Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Wed, 8 Apr 2026 22:36:18 -0700 Subject: [PATCH 12/12] fix: empty list is falsy in mock setup --- .../unit/model_executor/test_pool_configurator.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/registered/unit/model_executor/test_pool_configurator.py b/test/registered/unit/model_executor/test_pool_configurator.py index 1800ec2243c5..e196045f9a8f 100644 --- a/test/registered/unit/model_executor/test_pool_configurator.py +++ b/test/registered/unit/model_executor/test_pool_configurator.py @@ -60,8 +60,14 @@ def _make_model_runner( mc.head_dim = head_dim mc.v_head_dim = v_head_dim mc.is_hybrid_swa = is_hybrid_swa - mc.full_attention_layer_ids = full_attention_layer_ids or list(range(num_layers)) - mc.swa_attention_layer_ids = swa_attention_layer_ids or [] + mc.full_attention_layer_ids = ( + full_attention_layer_ids + if full_attention_layer_ids is not None + else list(range(num_layers)) + ) + mc.swa_attention_layer_ids = ( + swa_attention_layer_ids if swa_attention_layer_ids is not None else [] + ) mc.swa_head_dim = swa_head_dim or head_dim mc.swa_v_head_dim = swa_v_head_dim or v_head_dim mc.get_num_kv_heads = lambda tp_size: num_kv_heads