Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 336 additions & 0 deletions test/registered/unit/model_executor/test_pool_configurator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,336 @@
"""Unit tests for pool_configurator.py -- CPU only, no GPU required.

Tests the end-to-end computation: available_bytes -> MemoryPoolConfig,
verifying tokens are correct, constraints are respected, and memory
invariants hold (tokens * per_token_cost <= available_bytes).
"""

import contextlib
import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

from sglang.test.ci.ci_register import register_cpu_ci

register_cpu_ci(est_time=5, suite="stage-a-test-cpu")


@contextlib.contextmanager
def mock_cpu_env(kv_size=2, tp_size=1):
"""Mock GPU-dependent functions for CPU-only testing."""
with (
patch("torch._utils._element_size", return_value=kv_size),
patch(
"sglang.srt.model_executor.pool_configurator.get_attention_tp_size",
return_value=tp_size,
),
):
yield


def _make_model_runner(
*,
num_kv_heads=4,
head_dim=64,
v_head_dim=64,
num_layers=32,
use_mla_backend=False,
is_hybrid_swa=False,
full_attention_layer_ids=None,
swa_attention_layer_ids=None,
swa_num_kv_heads=None,
swa_head_dim=None,
swa_v_head_dim=None,
swa_full_tokens_ratio=0.5,
page_size=1,
mambaish_config=None,
):
"""Create a mock ModelRunner with the fields configurators need."""
mr = MagicMock()

mr.use_mla_backend = use_mla_backend
mr.is_draft_worker = False
mr.num_effective_layers = num_layers
mr.start_layer = 0
mr.end_layer = num_layers
mr.mambaish_config = mambaish_config
mr.is_hybrid_swa = is_hybrid_swa

mc = SimpleNamespace()
mc.head_dim = head_dim
mc.v_head_dim = v_head_dim
mc.is_hybrid_swa = is_hybrid_swa
mc.full_attention_layer_ids = (
full_attention_layer_ids
if full_attention_layer_ids is not None
else list(range(num_layers))
)
mc.swa_attention_layer_ids = (
swa_attention_layer_ids if swa_attention_layer_ids is not None else []
)
mc.swa_head_dim = swa_head_dim or head_dim
mc.swa_v_head_dim = swa_v_head_dim or v_head_dim
mc.get_num_kv_heads = lambda tp_size: num_kv_heads
mc.get_swa_num_kv_heads = lambda tp_size: swa_num_kv_heads or num_kv_heads
mc.hf_config = SimpleNamespace(architectures=["LlamaForCausalLM"])
mr.model_config = mc

mr.kv_cache_dtype = "fake_bf16"

sa = SimpleNamespace()
sa.swa_full_tokens_ratio = swa_full_tokens_ratio
sa.page_size = page_size
mr.server_args = sa

spec = MagicMock()
spec.is_dflash.return_value = False
spec.is_none.return_value = True
mr.spec_algorithm = spec

return mr


KV_SIZE = 2 # bf16


def _full_per_token(mr):
mc = mr.model_config
return mc.get_num_kv_heads(1) * (mc.head_dim + mc.v_head_dim) * KV_SIZE


def _swa_per_token(mr):
mc = mr.model_config
return mc.get_swa_num_kv_heads(1) * (mc.swa_head_dim + mc.swa_v_head_dim) * KV_SIZE


def _actual_memory_used(mr, config):
"""Compute actual memory consumed by the pool sizes in config."""
mc = mr.model_config
full_pt = _full_per_token(mr)
swa_pt = _swa_per_token(mr)
nf = len(mc.full_attention_layer_ids)
ns = len(mc.swa_attention_layer_ids)

if mr.is_hybrid_swa:
full = config.full_max_total_num_tokens or 0
swa = config.swa_max_total_num_tokens or 0
return full * full_pt * nf + swa * swa_pt * ns
else:
return config.max_total_num_tokens * full_pt * (nf + ns)


class TestDefaultConfigurator(unittest.TestCase):
"""Default (MHA): available_bytes -> tokens, memory invariant holds."""

def _run(self, available_bytes, page_size=1, **kwargs):
mr = _make_model_runner(page_size=page_size, **kwargs)
with mock_cpu_env():
from sglang.srt.model_executor.pool_configurator import (
create_memory_pool_configurator,
)

cfg = create_memory_pool_configurator(mr)
config = cfg.calculate_pool_sizes(available_bytes, page_size)
return mr, cfg, config

def test_memory_utilization(self):
"""Memory used should be <= available and within 1% of available."""
available = 10_000_000
mr, cfg, config = self._run(available)
used = _actual_memory_used(mr, config)
self.assertLessEqual(used, available)
self.assertGreater(used, available * 0.99)

def test_page_alignment(self):
available = 10_000_000
_, _, config = self._run(available, page_size=128)
self.assertEqual(config.max_total_num_tokens % 128, 0)

def test_constraint_respected(self):
"""calculate_pool_sizes_from_max_tokens respects the limit."""
mr, cfg, config = self._run(10_000_000)
with mock_cpu_env():
constrained = cfg.calculate_pool_sizes_from_max_tokens(100, page_size=1)
self.assertEqual(constrained.max_total_num_tokens, 100)

def test_constraint_page_aligned(self):
mr, cfg, _ = self._run(10_000_000, page_size=128)
with mock_cpu_env():
constrained = cfg.calculate_pool_sizes_from_max_tokens(1000, page_size=128)
self.assertEqual(constrained.max_total_num_tokens, 896) # 1000 // 128 * 128

def test_no_swa_fields(self):
_, _, config = self._run(10_000_000)
self.assertIsNone(config.full_max_total_num_tokens)
self.assertIsNone(config.swa_max_total_num_tokens)


class TestHybridSWAConfigurator(unittest.TestCase):
"""Hybrid SWA: full/swa split, ratio, memory invariant."""

def _make_swa_runner(self, full_layers=16, swa_layers=16, ratio=0.5, page_size=1):
return _make_model_runner(
is_hybrid_swa=True,
full_attention_layer_ids=list(range(full_layers)),
swa_attention_layer_ids=list(range(full_layers, full_layers + swa_layers)),
swa_num_kv_heads=4,
page_size=page_size,
swa_full_tokens_ratio=ratio,
)

def _run(self, available_bytes, **kwargs):
mr = self._make_swa_runner(**kwargs)
with mock_cpu_env():
from sglang.srt.model_executor.pool_configurator import (
create_memory_pool_configurator,
)

cfg = create_memory_pool_configurator(mr)
config = cfg.calculate_pool_sizes(available_bytes, mr.server_args.page_size)
return mr, cfg, config

def test_memory_utilization(self):
"""Memory used should be <= available and within 1% of available."""
available = 10_000_000
mr, _, config = self._run(available)
used = _actual_memory_used(mr, config)
self.assertLessEqual(used, available)
self.assertGreater(used, available * 0.99)

def test_ratio_respected(self):
"""swa_tokens ~= full_tokens * ratio (within page alignment)"""
available = 10_000_000
for ratio in [0.25, 0.5, 0.75, 1.0]:
mr, _, config = self._run(available, ratio=ratio, page_size=1)
full = config.full_max_total_num_tokens
swa = config.swa_max_total_num_tokens
self.assertEqual(swa, int(full * ratio), f"ratio={ratio}")

def test_ratio_with_page_alignment(self):
"""With page alignment, swa_tokens = align(full_tokens * ratio)"""
available = 10_000_000
mr, _, config = self._run(available, ratio=0.5, page_size=128)
full = config.full_max_total_num_tokens
swa = config.swa_max_total_num_tokens
self.assertEqual(full % 128, 0)
self.assertEqual(swa % 128, 0)
self.assertEqual(swa, (int(full * 0.5) // 128) * 128)

def test_max_total_equals_full(self):
"""For hybrid, max_total_num_tokens = full_max_total_num_tokens"""
_, _, config = self._run(10_000_000)
self.assertEqual(config.max_total_num_tokens, config.full_max_total_num_tokens)

def test_constraint_respected(self):
"""full_tokens = constrained value after re-run"""
mr, cfg, _ = self._run(10_000_000, page_size=1)
with mock_cpu_env():
config = cfg.calculate_pool_sizes_from_max_tokens(200, page_size=1)
self.assertEqual(config.full_max_total_num_tokens, 200)
self.assertEqual(config.swa_max_total_num_tokens, 100)

def test_constraint_memory_within_budget(self):
"""After constraint, memory <= original budget (but less than profiled due to constraint)."""
available = 10_000_000
mr, cfg, original = self._run(available, page_size=1)
user_limit = original.full_max_total_num_tokens // 2
with mock_cpu_env():
config = cfg.calculate_pool_sizes_from_max_tokens(
user_limit, mr.server_args.page_size
)
used = _actual_memory_used(mr, config)
self.assertLessEqual(used, available)
# constrained should use roughly half the memory
original_used = _actual_memory_used(mr, original)
self.assertAlmostEqual(used / original_used, 0.5, delta=0.01)

def test_different_layer_counts(self):
"""Asymmetric full/swa layer counts"""
available = 10_000_000
mr, _, config = self._run(available, full_layers=24, swa_layers=8, ratio=0.5)
used = _actual_memory_used(mr, config)
self.assertLessEqual(used, available)
self.assertEqual(
config.swa_max_total_num_tokens,
int(config.full_max_total_num_tokens * 0.5),
)


class TestAllSWAConfigurator(unittest.TestCase):
"""All-SWA (full_layers=0): special case."""

def _run(self, available_bytes, ratio=0.5, page_size=1):
mr = _make_model_runner(
is_hybrid_swa=True,
full_attention_layer_ids=[],
swa_attention_layer_ids=list(range(32)),
swa_num_kv_heads=4,
swa_full_tokens_ratio=ratio,
page_size=page_size,
)
with mock_cpu_env():
from sglang.srt.model_executor.pool_configurator import (
create_memory_pool_configurator,
)

cfg = create_memory_pool_configurator(mr)
config = cfg.calculate_pool_sizes(available_bytes, page_size)
return mr, cfg, config

def test_full_max_is_zero(self):
_, _, config = self._run(10_000_000)
self.assertEqual(config.full_max_total_num_tokens, 0)

def test_max_total_equals_swa(self):
_, _, config = self._run(10_000_000)
self.assertEqual(config.max_total_num_tokens, config.swa_max_total_num_tokens)

def test_memory_utilization(self):
"""Memory used should be <= available and within 1% of available."""
available = 10_000_000
mr, _, config = self._run(available)
swa_pt = _swa_per_token(mr)
ns = len(mr.model_config.swa_attention_layer_ids)
used = config.swa_max_total_num_tokens * swa_pt * ns
self.assertLessEqual(used, available)
self.assertGreater(used, available * 0.99)

def test_constraint_respected(self):
mr, cfg, _ = self._run(10_000_000, page_size=1)
with mock_cpu_env():
config = cfg.calculate_pool_sizes_from_max_tokens(500, page_size=1)
self.assertEqual(config.max_total_num_tokens, 500)
self.assertEqual(config.swa_max_total_num_tokens, 500)


class TestFactory(unittest.TestCase):
def test_default_for_non_swa(self):
mr = _make_model_runner(is_hybrid_swa=False)
with mock_cpu_env():
from sglang.srt.model_executor.pool_configurator import (
DefaultPoolConfigurator,
create_memory_pool_configurator,
)

cfg = create_memory_pool_configurator(mr)
self.assertIsInstance(cfg, DefaultPoolConfigurator)

def test_swa_for_hybrid(self):
mr = _make_model_runner(
is_hybrid_swa=True,
full_attention_layer_ids=list(range(16)),
swa_attention_layer_ids=list(range(16, 32)),
swa_num_kv_heads=4,
)
with mock_cpu_env():
from sglang.srt.model_executor.pool_configurator import (
HybridSWAPoolConfigurator,
create_memory_pool_configurator,
)

cfg = create_memory_pool_configurator(mr)
self.assertIsInstance(cfg, HybridSWAPoolConfigurator)


if __name__ == "__main__":
unittest.main()
Loading