python/sglang/srt/model_executor/model_runner.py (2 changes: 1 addition & 1 deletion)
@@ -1803,7 +1803,7 @@ def init_lora_manager(self):
     def _init_lora_cuda_graph_moe_buffers(self):
         """Phase 1 of LoRA CUDA graph init: pre-allocate MoE intermediate buffers.

-        Must be called before init_memory_pool() so that profile_max_num_token()
+        Must be called before init_memory_pool() so that memory profiling
         sees the reduced available memory and sizes KV cache correctly.
         All MoE LoRA layers share one set of buffers (managed by the
         lora_backend) since they execute sequentially during forward.
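The ordering constraint in this docstring is the point of the two-phase init: anything pre-allocated before memory profiling is automatically subtracted from the KV-cache budget, while anything allocated after it would overcommit the GPU. A minimal sketch of that effect with invented numbers (kv_token_budget and all byte counts below are hypothetical, not SGLang APIs):

# Hypothetical helper; the real profiling path uses get_available_gpu_memory().
def kv_token_budget(free_bytes: int, cell_size_per_token: int) -> int:
    # The KV cache can hold exactly this many tokens.
    return free_bytes // cell_size_per_token

free = 40 * (1 << 30)             # 40 GiB free after model load (made up)
moe_lora_buffers = 2 * (1 << 30)  # pre-allocated MoE LoRA buffers (made up)

# Buffers allocated BEFORE profiling: the profiler sees 38 GiB, budget is safe.
print(kv_token_budget(free - moe_lora_buffers, 131072))  # 311296
# Buffers allocated AFTER profiling would leave this oversized budget:
print(kv_token_budget(free, 131072))                     # 327680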
python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py (197 changes: 31 additions & 166 deletions)
@@ -72,83 +72,26 @@ def __post_init__(self):


 class ModelRunnerKVCacheMixin:
-    def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int:
-        kv_size = torch._utils._element_size(self.kv_cache_dtype)
-        if self.use_mla_backend:
-            cell_size = (
-                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
-                * num_layers
-                * kv_size
-            )
-            if is_float4_e2m1fn_x2(self.kv_cache_dtype):
-                # kv_scale_buffer
-                scale_block_size = 16
-                cell_size = (cell_size // 2) + (
-                    (
-                        (
-                            self.model_config.kv_lora_rank
-                            + self.model_config.qk_rope_head_dim
-                        )
-                        // scale_block_size
-                    )
-                    * num_layers
-                    * kv_size
-                )
-
-            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
-            if is_deepseek_nsa(self.model_config.hf_config):
-                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
-                indexer_size_per_token = (
-                    index_head_dim
-                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
-                )
-                element_size = torch._utils._element_size(
-                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
-                )
-                cell_size += indexer_size_per_token * num_layers * element_size
-        else:
-            if self.model_config.is_hybrid_swa:
-                full_layers_num = len(self.model_config.full_attention_layer_ids)
-                swa_layers_num = len(self.model_config.swa_attention_layer_ids)
-
-                full_per_token = self.model_config.get_num_kv_heads(
-                    get_attention_tp_size()
-                ) * (self.model_config.head_dim + self.model_config.v_head_dim)
-
-                swa_per_token = self.model_config.get_swa_num_kv_heads(
-                    get_attention_tp_size()
-                ) * (self.model_config.swa_head_dim + self.model_config.swa_v_head_dim)
-
-                cell_size = (
-                    full_per_token * full_layers_num + swa_per_token * swa_layers_num
-                ) * kv_size
-            else:
-                cell_size = (
-                    self.model_config.get_num_kv_heads(get_attention_tp_size())
-                    * (self.model_config.head_dim + self.model_config.v_head_dim)
-                    * num_layers
-                    * kv_size
-                )
-
-            if is_float4_e2m1fn_x2(self.kv_cache_dtype):
-                # kv_scale_buffer
-                scale_block_size = 16
-
-                n = self.model_config.get_num_kv_heads(get_attention_tp_size())
-                k = self.model_config.head_dim
-                cell_size = (cell_size // 2) + (
-                    (n * k * num_layers * 2 * kv_size) // scale_block_size
-                )
-        return cell_size
-
-    def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int):
+    def _profile_available_bytes(
+        self: ModelRunner, pre_model_load_memory: int
+    ) -> float:
         post_model_load_memory = get_available_gpu_memory(
             self.device,
             self.gpu_id,
             distributed=get_world_group().world_size > 1,
             cpu_group=get_world_group().cpu_group,
         )

+        rest_memory = post_model_load_memory - pre_model_load_memory * (
+            1 - self.mem_fraction_static
+        )
+        if self.mambaish_config is not None:
+            rest_memory = self.handle_max_mamba_cache(rest_memory)
+
+        return rest_memory * (1 << 30)  # return in bytes
+
+    def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int):
         # Get the number of layers used for KV cache calculation
         if self.is_draft_worker:
             num_layers = getattr(
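For the plain multi-head attention path in the helper removed above (now living in pool_configurator), the arithmetic reduces to heads x (key dim + value dim) x layers x element size. A quick check with invented, Llama-like numbers (none of these values are read from a real config):

# Invented values; the real ones come from self.model_config.
num_kv_heads = 8               # KV heads per GPU after tensor parallelism
head_dim = v_head_dim = 128    # per-head key and value dimensions
num_layers = 32
kv_size = 2                    # bytes per element for bf16/fp16

cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * kv_size
print(cell_size)  # 131072 bytes of KV cache per token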
@@ -166,7 +109,9 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int):
         else:
             num_layers = self.num_effective_layers

-        cell_size = self.get_cell_size_per_token(num_layers)
+        from sglang.srt.model_executor.pool_configurator import get_cell_size_per_token
+
+        cell_size = get_cell_size_per_token(self, num_layers)
         if self.spec_algorithm.is_dflash() and not self.is_draft_worker:
             from sglang.srt.speculative.dflash_utils import (
                 scale_kv_cell_size_per_token_for_dflash,
@@ -184,13 +129,8 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int):
                 draft_num_layers=int(draft_num_layers),
             )

-        rest_memory = post_model_load_memory - pre_model_load_memory * (
-            1 - self.mem_fraction_static
-        )
-        if self.mambaish_config is not None:
-            rest_memory = self.handle_max_mamba_cache(rest_memory)
-
-        return int(rest_memory * (1 << 30)) // cell_size
+        available_bytes = self._profile_available_bytes(pre_model_load_memory)
+        return int(available_bytes) // cell_size

     def handle_max_mamba_cache(self: ModelRunner, total_rest_memory):
         config = self.mambaish_config
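The deleted tail of profile_max_num_token and the new _profile_available_bytes compute the same budget: free memory after model load, minus the share of pre-load memory that mem_fraction_static leaves to other users, converted to bytes and divided by the per-token cell size. A numeric walk-through of that formula with invented figures (nothing here is measured):

# All figures are illustrative, not measured on a real GPU.
pre_model_load_memory = 80.0   # GiB free before weights were loaded
post_model_load_memory = 48.0  # GiB free after weights were loaded
mem_fraction_static = 0.9      # fraction reserved for weights + KV cache
cell_size = 131072             # bytes of KV cache per token (see above)

rest_memory = post_model_load_memory - pre_model_load_memory * (1 - mem_fraction_static)
available_bytes = rest_memory * (1 << 30)
max_num_tokens = int(available_bytes) // cell_size
print(max_num_tokens)  # 327680 tokens of KV-cache budget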
@@ -307,84 +247,12 @@ def _resolve_hybrid_swa_tokens(

         Returns (effective_capacity, full_max_total_num_tokens, swa_max_total_num_tokens).
         """
-        page_size = self.server_args.page_size
-
-        assert self.sliding_window_size is not None and self.sliding_window_size > 0
-        full_layers_num = len(self.model_config.full_attention_layer_ids)
-        swa_layers_num = len(self.model_config.swa_attention_layer_ids)
-
-        assert swa_layers_num > 0, "Hybrid SWA model must have at least one SWA layer"
-
-        def align_page_size(x: int) -> int:
-            return (x // page_size) * page_size
-
-        if full_layers_num == 0:
-            # all layers are SWA
-            swa_tokens = align_page_size(token_capacity)
-            logger.info(
-                f"Use sliding window memory pool (all SWA). swa_layer_tokens={swa_tokens}"
-            )
-            return swa_tokens, 0, swa_tokens
-
-        swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio
-
-        # Use unified memory-based allocation for all hybrid SWA models.
-        #
-        # Let:
-        #   F = Full layer per-token memory
-        #   S = SWA layer per-token memory (may differ from F)
-        #   r = swa_full_tokens_ratio = swa_tokens / full_tokens
-        #
-        # The profile phase computed:
-        #   cell_size = F * n_full + S * n_swa
-        #   token_capacity = rest_memory / cell_size
-        #   => total_memory = token_capacity * (F * n_full + S * n_swa)
-        #
-        # We need to solve:
-        #   full_tokens * F * n_full + swa_tokens * S * n_swa = total_memory
-        #   swa_tokens = full_tokens * r
-        #
-        # Solution:
-        #   full_tokens = total_memory / (F * n_full + r * S * n_swa)
-        #               = token_capacity * (F * n_full + S * n_swa) / (F * n_full + r * S * n_swa)
-
-        kv_size = torch._utils._element_size(self.kv_cache_dtype)
-
-        # Full layer per-token memory
-        full_per_token = (
-            self.model_config.get_num_kv_heads(get_attention_tp_size())
-            * (self.model_config.head_dim + self.model_config.v_head_dim)
-            * kv_size
-        )
-
-        # SWA layer per-token memory
-        swa_per_token = (
-            self.model_config.get_swa_num_kv_heads(get_attention_tp_size())
-            * (self.model_config.swa_head_dim + self.model_config.swa_v_head_dim)
-            * kv_size
+        from sglang.srt.model_executor.pool_configurator import (
+            resolve_hybrid_swa_tokens,
         )

-        # Total memory available from profile
-        total_memory = token_capacity * (
-            full_per_token * full_layers_num + swa_per_token * swa_layers_num
-        )
-
-        # Solve the equations
-        denominator = (
-            full_per_token * full_layers_num
-            + swa_full_tokens_ratio * swa_per_token * swa_layers_num
-        )
-        assert (
-            denominator > 0
-        ), f"Invalid denominator={denominator} for memory-based allocation. full_per_token={full_per_token}, full_layers_num={full_layers_num}, swa_per_token={swa_per_token}, swa_layers_num={swa_layers_num}, swa_full_tokens_ratio={swa_full_tokens_ratio}"
-
-        full_tokens = align_page_size(int(total_memory / denominator))
-        swa_tokens = align_page_size(int(full_tokens * swa_full_tokens_ratio))
-
-        logger.info(
-            f"Use sliding window memory pool. full_layer_tokens={full_tokens}, swa_layer_tokens={swa_tokens}"
-        )
-        return full_tokens, full_tokens, swa_tokens
+        assert self.sliding_window_size is not None and self.sliding_window_size > 0
+        return resolve_hybrid_swa_tokens(self, token_capacity)

     def _calculate_mamba_ratio(self: ModelRunner) -> int:
         if self.server_args.disable_radix_cache:
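The algebra in the removed comment block (which this PR moves to pool_configurator along with the code) is easy to verify numerically: with per-token costs F and S, layer counts n_full and n_swa, and ratio r, the solved token counts spend exactly the memory the profile phase found, up to integer truncation. Invented values throughout:

# Invented values for illustration only.
F, S = 2048, 1024          # per-token bytes for a full layer / a SWA layer
n_full, n_swa = 8, 24      # layer counts
r = 0.5                    # swa_full_tokens_ratio = swa_tokens / full_tokens
token_capacity = 100_000   # from the profile phase

total_memory = token_capacity * (F * n_full + S * n_swa)
full_tokens = int(total_memory / (F * n_full + r * S * n_swa))
swa_tokens = int(full_tokens * r)

# Memory spent equals memory profiled, up to one cell of truncation error:
spent = full_tokens * F * n_full + swa_tokens * S * n_swa
assert abs(spent - total_memory) <= F * n_full + S * n_swa
print(full_tokens, swa_tokens)  # 142857 71428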
@@ -821,37 +689,34 @@ def _init_pools(self: ModelRunner):
             self.token_to_kv_pool_allocator.full_to_swa_index_mapping
         )

-    def _resolve_token_capacity(self: ModelRunner, profiled_tokens: int) -> int:
-        """Compute final token pool capacity from profiled value,
-        applying user cap, page alignment, and PP sync"""
+    def _apply_token_constraints(self: ModelRunner, token_capacity: int) -> int:
+        """Apply external constraints to token capacity: user cap, page alignment, PP sync."""
         user_limit = self.server_args.max_total_tokens

         # Apply user-specified upper bound
         if user_limit is not None:
-            if user_limit > profiled_tokens:
+            if user_limit > token_capacity:
                 logging.warning(
                     f"max_total_tokens={user_limit} is larger than the profiled value "
-                    f"{profiled_tokens}. Use the profiled value instead."
+                    f"{token_capacity}. Use the profiled value instead."
                 )
-            capacity = min(profiled_tokens, user_limit)
-        else:
-            capacity = profiled_tokens
+            token_capacity = min(token_capacity, user_limit)

         # Align to page boundary
         page_size = self.server_args.page_size
-        capacity = capacity // page_size * page_size
+        token_capacity = token_capacity // page_size * page_size

         # Sync across PP ranks (each may have different layer counts)
         if self.pp_size > 1:
-            tensor = torch.tensor(capacity, dtype=torch.int64)
+            tensor = torch.tensor(token_capacity, dtype=torch.int64)
             torch.distributed.all_reduce(
                 tensor,
                 op=torch.distributed.ReduceOp.MIN,
                 group=get_world_group().cpu_group,
             )
-            capacity = tensor.item()
+            token_capacity = tensor.item()

-        return capacity
+        return token_capacity

     def _resolve_max_num_reqs(self: ModelRunner, token_capacity: int) -> int:
         """Compute max concurrent requests (per dp worker) from the finalized
@@ -889,7 +754,7 @@ def _resolve_memory_pool_config(
     ) -> MemoryPoolConfig:
         """Profile GPU memory and resolve all pool parameters into a config."""
         profiled_tokens = self.profile_max_num_token(pre_model_load_memory)
-        token_capacity = self._resolve_token_capacity(profiled_tokens)
+        token_capacity = self._apply_token_constraints(profiled_tokens)

         full_tokens = None
         swa_tokens = None