diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 38dc3e4f9..80c9f39c2 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -205,6 +205,12 @@ def __init__(
         self._running = False
         self._processing_task: Optional[asyncio.Task] = None
 
+        # Memory management: periodic mx.clear_cache() to free Metal buffer pool.
+        # _step_count counts step() calls since the last clear of any kind;
+        # _clear_cache_interval is the base period between periodic clears.
+        self._step_count = 0
+        self._clear_cache_interval = 32
+
         # Statistics
         self.num_requests_processed = 0
         self.total_prompt_tokens = 0
@@ -550,5 +556,25 @@ def step(self) -> MLLMSchedulerOutput:
         if finished_ids:
             mx.clear_cache()
+            # The pool was just flushed; restart the periodic countdown so the
+            # next periodic clear is a full interval away.
+            self._step_count = 0
 
+        # Adaptive periodic cache clear: scale inversely with concurrency
+        # to prevent Metal buffer pool growth during long generations.
+        # The interval is the base value below 16 running sequences, then
+        # base // (active // 8), floored at max(4, base // 4).
+        active_seqs = len(self.running)
+        min_interval = max(4, self._clear_cache_interval // 4)
+        effective_interval = max(
+            min_interval, self._clear_cache_interval // max(1, active_seqs // 8)
+        )
+
+        # Count steps since the last clear instead of testing a global counter
+        # modulo a changing interval, which can skip or bunch clears.
+        self._step_count += 1
+        if self._step_count >= effective_interval:
+            mx.clear_cache()
+            self._step_count = 0
+
         # Clear finished tracking for next step
         self.finished_req_ids = set()