diff --git a/tests/test_admission.py b/tests/test_admission.py
new file mode 100644
index 00000000..063bf42b
--- /dev/null
+++ b/tests/test_admission.py
@@ -0,0 +1,553 @@
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+import time
+
+from unittest.mock import patch
+
+import pytest
+
+from vllm_mlx.admission import (
+    AdmissionController,
+    MemoryMonitor,
+    RequestQueue,
+    compute_kv_per_token,
+)
+
+
+def test_kv_per_token_qwen35_35b():
+    """Qwen3.5-35B: 10 attn layers, 2 KV heads, 256 head_dim, bfloat16."""
+    result = compute_kv_per_token(
+        num_hidden_layers=40,
+        full_attention_interval=4,
+        num_kv_heads=2,
+        head_dim=256,
+        dtype_bytes=2,
+    )
+    assert result == 20_480  # 10 * 2 * 256 * 2 * 2
+
+
+def test_kv_per_token_qwen35_122b():
+    """Qwen3.5-122B: 12 attn layers, 2 KV heads, 256 head_dim, bfloat16."""
+    result = compute_kv_per_token(
+        num_hidden_layers=48,
+        full_attention_interval=4,
+        num_kv_heads=2,
+        head_dim=256,
+        dtype_bytes=2,
+    )
+    assert result == 24_576  # 12 * 2 * 256 * 2 * 2
+
+
+def test_kv_per_token_dense_model():
+    """Dense model (no interval): all layers are attention."""
+    result = compute_kv_per_token(
+        num_hidden_layers=32,
+        full_attention_interval=1,
+        num_kv_heads=8,
+        head_dim=128,
+        dtype_bytes=2,
+    )
+    assert result == 32 * 8 * 128 * 2 * 2  # 131_072
+
+
+def test_kv_per_token_nemotron_h_hybrid_pattern():
+    """Nemotron-H: 8 attention layers from hybrid_override_pattern, not 88."""
+    pattern = "MEMEMEM*EMEMEMEM*EMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEM*EMEMEMEME"
+    result = compute_kv_per_token(
+        num_hidden_layers=88,
+        full_attention_interval=1,  # default — would give 88 without pattern
+        num_kv_heads=2,
+        head_dim=128,
+        dtype_bytes=2,
+        hybrid_override_pattern=pattern,
+    )
+    # 8 attention layers (count of '*' in pattern), not 88
+    assert result == 8 * 2 * 128 * 2 * 2  # 8_192
+    # Without pattern, it would be 88 * 2 * 128 * 2 * 2 = 90_112 (11x overestimate)
+    wrong = compute_kv_per_token(
+        num_hidden_layers=88,
+        full_attention_interval=1,
+        num_kv_heads=2,
+        head_dim=128,
+        dtype_bytes=2,
+    )
+    assert wrong == 90_112
+    assert wrong / result == 11.0
+
+
+def test_memory_monitor_free_memory():
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 80 * 1024**3
+        mock_mx.get_cache_memory.return_value = 10 * 1024**3
+        monitor = MemoryMonitor()
+        free = monitor.free_memory()
+        assert free == 30 * 1024**3  # 120 - 80 - 10
+
+
+def test_memory_monitor_can_admit():
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 95 * 1024**3
+        mock_mx.get_cache_memory.return_value = 5 * 1024**3
+        monitor = MemoryMonitor(headroom_bytes=8 * 1024**3)
+        # 20 GB free, 8 GB headroom, 5 GB prefill: 20 >= 13 → admit
+        assert monitor.can_admit(prefill_bytes=5 * 1024**3) is True
+        # 20 GB free, 8 GB headroom, 15 GB prefill: 20 < 23 → reject
+        assert monitor.can_admit(prefill_bytes=15 * 1024**3) is False
+
+
+def test_request_queue_fifo():
+    q = RequestQueue(policy="fifo")
+    q.enqueue("req-1", prompt_tokens=1000)
+    q.enqueue("req-2", prompt_tokens=500)
+    q.enqueue("req-3", prompt_tokens=2000)
+    # FIFO: order of insertion
+    assert q.peek().request_id == "req-1"
+    assert q.dequeue().request_id == "req-1"
+    assert q.dequeue().request_id == "req-2"
+    assert q.dequeue().request_id == "req-3"
+    assert q.is_empty()
+
+
+def test_request_queue_length():
+    q = RequestQueue(policy="fifo")
+    assert len(q) == 0
+    q.enqueue("req-1", prompt_tokens=100)
+    assert len(q) == 1
+    q.dequeue()
+    assert len(q) == 0
+
+
+def test_request_queue_cancel():
+    q = RequestQueue(policy="fifo")
+    q.enqueue("req-1", prompt_tokens=100)
+    q.enqueue("req-2", prompt_tokens=200)
+    q.cancel("req-1")
+    assert q.dequeue().request_id == "req-2"
+
+
+def test_admission_controller_admit_when_room():
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 50 * 1024**3  # 70 GB free
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480,
+            headroom_bytes=8 * 1024**3,
+        )
+        admitted, position = controller.try_admit("req-1", prompt_tokens=10_000)
+        assert admitted is True
+        assert position is None
+
+
+def test_admission_controller_queue_when_full():
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3  # 5 GB free
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480,
+            headroom_bytes=8 * 1024**3,
+        )
+        # 10K tokens * 20KB = 200MB prefill. 5GB free < 200MB + 8GB → queue
+        admitted, position = controller.try_admit("req-1", prompt_tokens=10_000)
+        assert admitted is False
+        assert position == 0
+
+
+def test_admission_controller_drain_queue():
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        # Start full
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480,
+            headroom_bytes=8 * 1024**3,
+        )
+        controller.try_admit("req-1", prompt_tokens=1000)
+        assert controller.queue_length == 1
+        # Memory freed
+        mock_mx.get_active_memory.return_value = 50 * 1024**3
+        ready = controller.check_queue()
+        assert len(ready) == 1
+        assert ready[0].request_id == "req-1"
+        assert controller.queue_length == 0
+
+
+def test_admission_controller_fifo_no_bypass():
+    """Small requests must queue behind large ones even if they'd fit."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 112 * 1024**3  # 8 GB free
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480,
+            headroom_bytes=8 * 1024**3,
+        )
+        # Large request can't fit (~3.8GB prefill + 8GB headroom > 8GB free)
+        admitted1, pos1 = controller.try_admit("req-large", prompt_tokens=200_000)
+        assert admitted1 is False
+        assert pos1 == 0
+        # The small request is queued without a memory check: the queue is
+        # non-empty, and FIFO forbids it from cutting ahead of the large
+        # request even when memory frees up later.
+        admitted2, pos2 = controller.try_admit("req-small", prompt_tokens=1)
+        assert admitted2 is False
+        assert pos2 == 1
+
+
+def test_admission_controller_cancel():
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480,
+            headroom_bytes=8 * 1024**3,
+        )
+        controller.try_admit("req-1", prompt_tokens=1000)
+        assert controller.queue_length == 1
+        assert controller.cancel("req-1") is True
+        assert controller.queue_length == 0
+        assert controller.cancel("nonexistent") is False
+
+
+def test_admission_controller_wait_for_admission():
+    """Test async wait — admitted immediately when room."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 50 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480, headroom_bytes=8 * 1024**3
+        )
+        # Should return immediately — plenty of memory
+        asyncio.run(
+            controller.wait_for_admission("req-1", prompt_tokens=100)
+        )
+        assert controller.queue_length == 0
+
+
+def test_admission_controller_on_request_complete():
+    """Test that completing a request releases queued ones."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480, headroom_bytes=8 * 1024**3
+        )
+
+        # Queue a request (not enough memory)
+        admitted, pos = controller.try_admit("req-1", prompt_tokens=1000)
+        assert admitted is False
+
+        # Set up wait event manually
+        event = asyncio.Event()
+        controller._wait_events["req-1"] = event
+
+        # Free memory and signal completion
+        mock_mx.get_active_memory.return_value = 50 * 1024**3
+        ready = controller.on_request_complete()
+        assert len(ready) == 1
+        assert ready[0].request_id == "req-1"
+        assert event.is_set()  # Waiter was signaled
+
+
+def test_admission_controller_cancel_wakes_waiter():
+    """Test that cancelling a queued request wakes up its waiter."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480, headroom_bytes=8 * 1024**3
+        )
+
+        # Queue a request
+        controller.try_admit("req-1", prompt_tokens=1000)
+
+        # Set up wait event
+        event = asyncio.Event()
+        controller._wait_events["req-1"] = event
+
+        # Cancel should wake the waiter
+        assert controller.cancel("req-1") is True
+        assert event.is_set()
+        assert "req-1" not in controller._wait_events
+
+
+def test_admission_wait_cleans_up_on_cancellation():
+    """Verify CancelledError during wait cleans up queue and events."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480, headroom_bytes=8 * 1024**3
+        )
+
+        async def simulate_cancel():
+            task = asyncio.create_task(
+                controller.wait_for_admission("req-cancel", prompt_tokens=1000)
+            )
+            await asyncio.sleep(0)  # Let it reach the await event.wait()
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+            # Queue and events should be cleaned up
+            assert controller.queue_length == 0
+            assert "req-cancel" not in controller._wait_events
+
+        asyncio.run(simulate_cancel())
+
+
+def test_admission_controller_eviction_callback():
+    """When queue is non-empty and memory tight, evict prefix cache."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480, headroom_bytes=8 * 1024**3
+        )
+
+        eviction_count = []
+
+        def mock_evict():
+            eviction_count.append(1)
+            # Simulate freeing memory by reducing active
+            mock_mx.get_active_memory.return_value = 50 * 1024**3
+            return True  # True = evicted something
+
+        controller.set_eviction_callback(mock_evict)
+
+        # Queue a request (not enough memory)
+        controller.try_admit("req-1", prompt_tokens=1000)
+        assert controller.queue_length == 1
+
+        # check_queue should call eviction, then admit
+        ready = controller.check_queue()
+        assert len(eviction_count) > 0  # Eviction was called
+        assert len(ready) == 1  # Request was admitted after eviction
+        assert ready[0].request_id == "req-1"
+
+
+def test_admission_controller_eviction_no_callback():
+    """Without eviction callback, check_queue just waits."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480, headroom_bytes=8 * 1024**3
+        )
+        controller.try_admit("req-1", prompt_tokens=1000)
+        ready = controller.check_queue()
+        assert len(ready) == 0  # Can't admit, no eviction callback
+
+
+def test_admission_controller_eviction_exhausted():
+    """If eviction can't free enough memory, request stays queued."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480, headroom_bytes=8 * 1024**3
+        )
+
+        def mock_evict_nothing():
+            return False  # Nothing to evict
+
+        controller.set_eviction_callback(mock_evict_nothing)
+        controller.try_admit("req-1", prompt_tokens=1000)
+        ready = controller.check_queue()
+        assert len(ready) == 0  # Still can't admit
+        assert controller.queue_length == 1
+
+
+# =====================================================================
+# Phase B3: Queue policy tests
+# =====================================================================
+
+
+def test_request_queue_shortest_first_basic():
+    """shortest_first dequeues smallest prompt first."""
+    q = RequestQueue(policy="shortest_first")
+    q.enqueue("req-large", prompt_tokens=100_000)
+    q.enqueue("req-small", prompt_tokens=1_000)
+    q.enqueue("req-medium", prompt_tokens=50_000)
+    assert q.dequeue().request_id == "req-small"
+    assert q.dequeue().request_id == "req-medium"
+    assert q.dequeue().request_id == "req-large"
+    assert q.is_empty()
+
+
+def test_request_queue_shortest_first_tiebreak_fifo():
+    """shortest_first breaks ties by enqueue order."""
+    q = RequestQueue(policy="shortest_first")
+    q.enqueue("req-a", prompt_tokens=1000)
+    q.enqueue("req-b", prompt_tokens=1000)
+    q.enqueue("req-c", prompt_tokens=1000)
+    assert q.dequeue().request_id == "req-a"
+    assert q.dequeue().request_id == "req-b"
+    assert q.dequeue().request_id == "req-c"
+
+
+def test_request_queue_shortest_first_starvation_guard():
+    """Requests waiting past starvation_timeout_s get dequeued first."""
+    q = RequestQueue(policy="shortest_first", starvation_timeout_s=60.0)
+    q.enqueue("req-large", prompt_tokens=100_000)
+    q.enqueue("req-small", prompt_tokens=1_000)
+    # Simulate req-large has been waiting 120s (past the 60s timeout)
+    q._queue[0].enqueued_at = time.time() - 120.0
+    # Starved request gets effective_tokens=0, dequeued before req-small
+    assert q.dequeue().request_id == "req-large"
+    assert q.dequeue().request_id == "req-small"
+
+
+def test_request_queue_shortest_first_peek():
+    """peek() returns the same entry that dequeue() would."""
+    q = RequestQueue(policy="shortest_first")
+    q.enqueue("req-large", prompt_tokens=100_000)
+    q.enqueue("req-small", prompt_tokens=1_000)
+    peeked = q.peek()
+    assert peeked.request_id == "req-small"
+    dequeued = q.dequeue()
+    assert dequeued.request_id == "req-small"
+
+
+def test_request_queue_priority_basic():
+    """priority dequeues highest priority first."""
+    q = RequestQueue(policy="priority")
+    q.enqueue("req-low", prompt_tokens=1000, priority=0)
+    q.enqueue("req-high", prompt_tokens=1000, priority=10)
+    q.enqueue("req-mid", prompt_tokens=1000, priority=5)
+    assert q.dequeue().request_id == "req-high"
+    assert q.dequeue().request_id == "req-mid"
+    assert q.dequeue().request_id == "req-low"
+
+
+def test_request_queue_priority_tiebreak_fifo():
+    """Same-priority requests ordered FIFO."""
+    q = RequestQueue(policy="priority")
+    q.enqueue("req-a", prompt_tokens=1000, priority=5)
+    q.enqueue("req-b", prompt_tokens=1000, priority=5)
+    q.enqueue("req-c", prompt_tokens=1000, priority=5)
+    assert q.dequeue().request_id == "req-a"
+    assert q.dequeue().request_id == "req-b"
+    assert q.dequeue().request_id == "req-c"
+
+
+def test_request_queue_invalid_policy():
+    """Invalid policy raises ValueError."""
+    with pytest.raises(ValueError, match="Unsupported policy"):
+        RequestQueue(policy="round_robin")
+
+
+def test_admission_controller_shortest_first_admits_small():
+    """shortest_first admits a small request even when a large one can't fit."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        # 10 GB free (just enough for small, not large)
+        mock_mx.get_active_memory.return_value = 110 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480,
+            headroom_bytes=8 * 1024**3,
+            policy="shortest_first",
+        )
+        # Both end up queued, for different reasons:
+        # Large: 200K * 20KB = 4GB prefill + 8GB headroom = 12GB > 10GB free → queue
+        admitted1, _ = controller.try_admit("req-large", prompt_tokens=200_000)
+        assert admitted1 is False
+        # Small: 100 * 20KB ≈ 2MB + 8GB headroom < 10GB → would fit, but
+        # the queue is non-empty so it joins the queue behind the large one
+        admitted2, _ = controller.try_admit("req-small", prompt_tokens=100)
+        assert admitted2 is False
+        assert controller.queue_length == 2
+
+        # Now memory frees to 20GB — shortest_first admits small first
+        mock_mx.get_active_memory.return_value = 100 * 1024**3
+        ready = controller.on_request_complete()
+        # Both should be admittable now, but small is dequeued first
+        assert len(ready) >= 1
+        assert ready[0].request_id == "req-small"
+
+
+def test_admission_controller_priority_order():
+    """priority policy admits highest-priority request first."""
+    with patch("vllm_mlx.admission.mx") as mock_mx:
+        mock_mx.metal.is_available.return_value = True
+        mock_mx.device_info.return_value = {
+            "max_recommended_working_set_size": 120 * 1024**3
+        }
+        mock_mx.get_active_memory.return_value = 115 * 1024**3
+        mock_mx.get_cache_memory.return_value = 0
+        controller = AdmissionController(
+            kv_per_token=20_480,
+            headroom_bytes=8 * 1024**3,
+            policy="priority",
+        )
+        controller.try_admit("req-low", prompt_tokens=1000, priority=0)
+        controller.try_admit("req-high", prompt_tokens=1000, priority=10)
+        controller.try_admit("req-mid", prompt_tokens=1000, priority=5)
+        assert controller.queue_length == 3
+
+        # Free memory — highest priority dequeued first
+        mock_mx.get_active_memory.return_value = 50 * 1024**3
+        ready = controller.on_request_complete()
+        assert len(ready) == 3
+        assert ready[0].request_id == "req-high"
+        assert ready[1].request_id == "req-mid"
+        assert ready[2].request_id == "req-low"
diff --git a/vllm_mlx/admission.py b/vllm_mlx/admission.py
new file mode 100644
index 00000000..00f9d318
--- /dev/null
+++ b/vllm_mlx/admission.py
@@ -0,0 +1,399 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Memory-aware admission controller for multi-user inference.
+
+Core principle: Load affects latency, never quality.
+Once a request starts generating, it runs to completion.
+The admission controller only decides WHEN to start.
+"""
+
+import asyncio
+import logging
+import time
+from collections import deque
+from dataclasses import dataclass, field
+from typing import Callable, List, Optional, Tuple
+
+import mlx.core as mx
+
+logger = logging.getLogger(__name__)
+
+
+def compute_kv_per_token(
+    num_hidden_layers: int,
+    full_attention_interval: int,
+    num_kv_heads: int,
+    head_dim: int,
+    dtype_bytes: int = 2,
+    hybrid_override_pattern: Optional[str] = None,
+) -> int:
+    """Compute KV cache bytes per token for this model.
+
+    Only full attention layers contribute to KV cache.
+    Linear attention (GatedDeltaNet) layers use fixed-size
+    SSM state regardless of context length.
+
+    Args:
+        num_hidden_layers: Total layer count.
+        full_attention_interval: Every Nth layer is full attention.
+            1 = all attention (dense), 4 = 25% attention (Qwen3.5).
+        num_kv_heads: Number of KV attention heads.
+        head_dim: Dimension per head.
+        dtype_bytes: Bytes per element (2 for bfloat16, 1 for int8 quantized).
+        hybrid_override_pattern: Layer-type string from config.json (e.g.
+            Nemotron-H "MEMEMEM*E..."). '*' = attention, 'M' = Mamba,
+            'E' = MoE, '-' = MLP. When provided, attention layers are
+            counted directly from the pattern instead of using
+            full_attention_interval.
+
+    Returns:
+        Bytes of KV cache consumed per token.
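+
+    Example (Qwen3.5-35B-style dimensions, as exercised in the tests; 40
+    layers at interval 4 gives 10 attention layers, and
+    10 * 2 * 256 * 2 * 2 = 20480):
+
+        >>> compute_kv_per_token(
+        ...     num_hidden_layers=40,
+        ...     full_attention_interval=4,
+        ...     num_kv_heads=2,
+        ...     head_dim=256,
+        ... )
+        20480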
+ """ + if hybrid_override_pattern: + attention_layers = hybrid_override_pattern.count("*") + else: + if full_attention_interval <= 0: + full_attention_interval = 1 + attention_layers = num_hidden_layers // full_attention_interval + return attention_layers * num_kv_heads * head_dim * 2 * dtype_bytes # 2 = K + V + + +class MemoryMonitor: + """Reads actual Metal memory to make admission decisions. + + Uses mx.get_active_memory() + mx.get_cache_memory() for true GPU + memory usage, not just live tensors. + """ + + def __init__(self, headroom_bytes: int = 8 * 1024**3): + self.headroom_bytes = headroom_bytes + if mx.metal.is_available(): + info = mx.device_info() + self._device_usable = info["max_recommended_working_set_size"] + else: + self._device_usable = 0 + + def free_memory(self) -> int: + """Return bytes of free GPU-usable memory.""" + if not mx.metal.is_available(): + return 0 + return max( + 0, self._device_usable - mx.get_active_memory() - mx.get_cache_memory() + ) + + def can_admit(self, prefill_bytes: int) -> bool: + """Can we admit a request that needs prefill_bytes of KV cache?""" + return self.free_memory() >= prefill_bytes + self.headroom_bytes + + +_VALID_POLICIES = ("fifo", "shortest_first", "priority") + + +@dataclass +class QueuedRequest: + request_id: str + prompt_tokens: int + priority: int = 0 # Higher = more important (only used by "priority" policy) + enqueued_at: float = field(default_factory=time.time) + + +class RequestQueue: + """Request queue with configurable ordering policy. + + Policies: + fifo: First-in-first-out. Fair, prevents starvation. Head-of-line + blocking: a large request at the front blocks smaller ones behind it. + shortest_first: Dequeue the request with the fewest prompt tokens. + Maximizes throughput by clearing short requests first. Starvation + guard: requests waiting longer than ``starvation_timeout_s`` are + promoted to highest priority. + priority: Dequeue the request with the highest ``priority`` value. + Same-priority requests are ordered FIFO (by enqueue time). + """ + + def __init__(self, policy: str = "fifo", starvation_timeout_s: float = 120.0): + if policy not in _VALID_POLICIES: + raise ValueError( + f"Unsupported policy: {policy!r}. " + f"Valid policies: {', '.join(_VALID_POLICIES)}" + ) + self._policy = policy + self._starvation_timeout_s = starvation_timeout_s + self._queue: deque[QueuedRequest] = deque() + + @property + def policy(self) -> str: + return self._policy + + def enqueue(self, request_id: str, prompt_tokens: int, priority: int = 0) -> int: + """Add request to queue. 
+
+
+_VALID_POLICIES = ("fifo", "shortest_first", "priority")
+
+
+@dataclass
+class QueuedRequest:
+    request_id: str
+    prompt_tokens: int
+    priority: int = 0  # Higher = more important (only used by "priority" policy)
+    enqueued_at: float = field(default_factory=time.time)
+
+
+class RequestQueue:
+    """Request queue with configurable ordering policy.
+
+    Policies:
+        fifo: First-in-first-out. Fair, prevents starvation. Head-of-line
+            blocking: a large request at the front blocks smaller ones
+            behind it.
+        shortest_first: Dequeue the request with the fewest prompt tokens.
+            Maximizes throughput by clearing short requests first. Starvation
+            guard: requests waiting longer than ``starvation_timeout_s`` are
+            promoted to highest priority.
+        priority: Dequeue the request with the highest ``priority`` value.
+            Same-priority requests are ordered FIFO (by enqueue time).
+    """
+
+    def __init__(self, policy: str = "fifo", starvation_timeout_s: float = 120.0):
+        if policy not in _VALID_POLICIES:
+            raise ValueError(
+                f"Unsupported policy: {policy!r}. "
+                f"Valid policies: {', '.join(_VALID_POLICIES)}"
+            )
+        self._policy = policy
+        self._starvation_timeout_s = starvation_timeout_s
+        self._queue: deque[QueuedRequest] = deque()
+
+    @property
+    def policy(self) -> str:
+        return self._policy
+
+    def enqueue(self, request_id: str, prompt_tokens: int, priority: int = 0) -> int:
+        """Add request to queue. Returns position (0-indexed)."""
+        entry = QueuedRequest(
+            request_id=request_id, prompt_tokens=prompt_tokens, priority=priority
+        )
+        self._queue.append(entry)
+        position = len(self._queue) - 1
+        logger.info(
+            f"[queue] {request_id} queued at position {position} "
+            f"({prompt_tokens} tokens, priority={priority})"
+        )
+        return position
+
+    def peek(self) -> Optional[QueuedRequest]:
+        """Return the next request that would be dequeued, without removing it."""
+        if not self._queue:
+            return None
+        if self._policy == "fifo":
+            return self._queue[0]
+        return self._select_next()
+
+    def dequeue(self) -> Optional[QueuedRequest]:
+        """Remove and return next request per policy."""
+        if not self._queue:
+            return None
+        if self._policy == "fifo":
+            return self._queue.popleft()
+        entry = self._select_next()
+        if entry is not None:
+            self._queue.remove(entry)
+        return entry
+
+    def _select_next(self) -> Optional[QueuedRequest]:
+        """Select the next request to dequeue based on policy.
+
+        For shortest_first: pick the request with fewest prompt_tokens.
+        Starvation guard: any request waiting longer than starvation_timeout_s
+        is treated as having 0 tokens (highest priority to dequeue).
+
+        For priority: pick the request with highest priority value.
+        Ties broken by earliest enqueue time (FIFO within same priority).
+        """
+        if not self._queue:
+            return None
+
+        now = time.time()
+
+        if self._policy == "shortest_first":
+            best = None
+            for entry in self._queue:
+                waited = now - entry.enqueued_at
+                starved = waited >= self._starvation_timeout_s
+                # Starved requests get effective size 0 (dequeue first)
+                effective_tokens = 0 if starved else entry.prompt_tokens
+                if best is None:
+                    best = (entry, effective_tokens, entry.enqueued_at)
+                else:
+                    _, best_tokens, best_time = best
+                    # Prefer fewer tokens; break ties by earlier enqueue
+                    if effective_tokens < best_tokens or (
+                        effective_tokens == best_tokens
+                        and entry.enqueued_at < best_time
+                    ):
+                        best = (entry, effective_tokens, entry.enqueued_at)
+            return best[0] if best else None
+
+        elif self._policy == "priority":
+            best = None
+            for entry in self._queue:
+                if best is None:
+                    best = entry
+                elif entry.priority > best.priority:
+                    best = entry
+                elif (
+                    entry.priority == best.priority
+                    and entry.enqueued_at < best.enqueued_at
+                ):
+                    best = entry
+            return best
+
+        # Fallback (shouldn't reach here — FIFO handled in dequeue)
+        return self._queue[0]
+
+    def cancel(self, request_id: str) -> bool:
+        """Remove a request from the queue. Returns True if found."""
+        for i, entry in enumerate(self._queue):
+            if entry.request_id == request_id:
+                del self._queue[i]
+                return True
+        return False
+
+    def is_empty(self) -> bool:
+        return len(self._queue) == 0
+
+    def __len__(self) -> int:
+        return len(self._queue)
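+
+
+# Example (illustrative; mirrors the tests): under shortest_first, the
+# smallest prompt is dequeued first regardless of arrival order.
+#
+#     q = RequestQueue(policy="shortest_first")
+#     q.enqueue("req-large", prompt_tokens=100_000)
+#     q.enqueue("req-small", prompt_tokens=1_000)
+#     q.dequeue().request_id  # -> "req-small"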
+
+
+class AdmissionController:
+    """Flow control for multi-user inference.
+
+    Core principle: Load affects latency, never quality.
+    Decides WHEN to start requests. Once started, a request
+    runs to completion at full quality.
+    """
+
+    def __init__(
+        self,
+        kv_per_token: int,
+        headroom_bytes: int = 8 * 1024**3,
+        policy: str = "fifo",
+    ):
+        self.kv_per_token = kv_per_token
+        self._monitor = MemoryMonitor(headroom_bytes=headroom_bytes)
+        self._queue = RequestQueue(policy=policy)
+        self._wait_events: dict[str, asyncio.Event] = {}
+        self._cancelled_ids: set[str] = set()
+        self._eviction_callback: Optional[Callable[[], bool]] = None
+
+    def try_admit(
+        self, request_id: str, prompt_tokens: int, priority: int = 0
+    ) -> Tuple[bool, Optional[int]]:
+        """Try to admit a request for immediate processing.
+
+        Returns:
+            (True, None) if admitted.
+            (False, queue_position) if queued.
+        """
+        # If anything is already waiting, join the queue regardless of memory;
+        # no request cuts the line at admission time. Policy ordering (fifo,
+        # shortest_first, priority) is applied when the queue drains.
+        if not self._queue.is_empty():
+            position = self._queue.enqueue(request_id, prompt_tokens, priority)
+            return False, position
+        prefill_bytes = prompt_tokens * self.kv_per_token
+        if self._monitor.can_admit(prefill_bytes):
+            logger.info(
+                f"[admit] {request_id} ADMITTED ({prompt_tokens} tokens, "
+                f"{prefill_bytes / 1e6:.0f} MB prefill, "
+                f"{self._monitor.free_memory() / 1e9:.1f} GB free)"
+            )
+            return True, None
+        position = self._queue.enqueue(request_id, prompt_tokens, priority)
+        return False, position
+
+    def set_eviction_callback(self, callback: Callable[[], bool]) -> None:
+        """Set a callback to evict prefix cache entries under memory pressure.
+
+        The callback should return True if it evicted something, False if
+        there's nothing left to evict.
+        """
+        self._eviction_callback = callback
+
+    def check_queue(self) -> List[QueuedRequest]:
+        """Check if queued requests can now be admitted.
+
+        Called after a request completes or memory is freed.
+
+        Policy behavior:
+            fifo: Only admit the front of the queue. Never skips.
+            shortest_first: Admit the smallest request that fits, even if
+                it's not at the front. Starved requests (past timeout) are
+                treated as smallest.
+            priority: Admit the highest-priority request that fits. Ties
+                broken by FIFO order.
+
+        If an eviction callback is registered and the selected request can't
+        fit, the callback is invoked to free prefix cache memory.
+
+        Returns list of newly-admittable requests.
+        """
+        ready = []
+
+        if self._queue.policy == "fifo":
+            # FIFO: strict head-of-line. Never skip the front.
+            while not self._queue.is_empty():
+                entry = self._queue.peek()
+                prefill_bytes = entry.prompt_tokens * self.kv_per_token
+                if self._monitor.can_admit(prefill_bytes):
+                    ready.append(self._queue.dequeue())
+                    logger.info(
+                        f"[admit] {entry.request_id} DEQUEUED → admitted "
+                        f"(waited {time.time() - entry.enqueued_at:.1f}s)"
+                    )
+                else:
+                    if (
+                        self._eviction_callback is not None
+                        and self._eviction_callback()
+                    ):
+                        continue
+                    break
+        else:
+            # shortest_first / priority: scan for best admittable request
+            while not self._queue.is_empty():
+                candidate = self._queue.peek()  # Policy-ordered best
+                prefill_bytes = candidate.prompt_tokens * self.kv_per_token
+                if self._monitor.can_admit(prefill_bytes):
+                    self._queue.dequeue()
+                    ready.append(candidate)
+                    logger.info(
+                        f"[admit] {candidate.request_id} DEQUEUED → admitted "
+                        f"(waited {time.time() - candidate.enqueued_at:.1f}s, "
+                        f"policy={self._queue.policy})"
+                    )
+                else:
+                    # Best candidate doesn't fit. Try eviction once.
+                    if (
+                        self._eviction_callback is not None
+                        and self._eviction_callback()
+                    ):
+                        continue
+                    break
+
+        return ready
+
+    async def wait_for_admission(
+        self,
+        request_id: str,
+        prompt_tokens: int,
+        priority: int = 0,
+        timeout: float = 600.0,
+    ) -> None:
+        """Block until this request can be admitted. Returns immediately if room.
+
+        Raises asyncio.TimeoutError after *timeout* seconds (default 600).
+        Raises asyncio.CancelledError if cancel() is called for this request.
+        The caller MUST call on_request_complete() in a finally block after
+        the request finishes generating.
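+
+        Typical usage (sketch; ``generate`` is a stand-in for the caller's
+        own request handler):
+
+            await controller.wait_for_admission(request_id, prompt_tokens)
+            try:
+                await generate(request_id)
+            finally:
+                controller.on_request_complete()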
+ """ + admitted, position = self.try_admit(request_id, prompt_tokens, priority) + if admitted: + return + event = asyncio.Event() + self._wait_events[request_id] = event + self._cancelled_ids: set = getattr(self, "_cancelled_ids", set()) + logger.info(f"[admit] {request_id} waiting for admission (position {position})") + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + self._queue.cancel(request_id) + self._wait_events.pop(request_id, None) + raise + except asyncio.CancelledError: + self._queue.cancel(request_id) + self._wait_events.pop(request_id, None) + raise + self._wait_events.pop(request_id, None) + # Distinguish "admitted" from "cancelled" — cancel() marks the id + if request_id in self._cancelled_ids: + self._cancelled_ids.discard(request_id) + raise asyncio.CancelledError(f"Request {request_id} was cancelled") + + def on_request_complete(self) -> List[QueuedRequest]: + """Called when a request completes. Drains queue and signals waiters.""" + ready = self.check_queue() + for entry in ready: + event = self._wait_events.get(entry.request_id) + if event: + event.set() + return ready + + def cancel(self, request_id: str) -> bool: + """Cancel a queued request. The waiter gets asyncio.CancelledError.""" + result = self._queue.cancel(request_id) + self._cancelled_ids: set = getattr(self, "_cancelled_ids", set()) + self._cancelled_ids.add(request_id) + event = self._wait_events.pop(request_id, None) + if event: + event.set() # Wake up waiter — it checks _cancelled_ids and raises + return result + + @property + def queue_length(self) -> int: + return len(self._queue)