diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index 840bb069..1af8c9d9 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -754,7 +754,16 @@ def get_stats(self) -> dict[str, Any]:
         }
 
         if self._mllm_scheduler:
-            stats["mllm_scheduler"] = self._mllm_scheduler.get_stats()
+            mllm_stats = self._mllm_scheduler.get_stats()
+            stats["mllm_scheduler"] = mllm_stats
+            # Promote Metal memory stats to top-level for /v1/status
+            for key in (
+                "metal_active_memory_gb",
+                "metal_peak_memory_gb",
+                "metal_cache_memory_gb",
+            ):
+                if key in mllm_stats:
+                    stats[key] = mllm_stats[key]
         elif self._engine:
             stats.update(self._engine.get_stats())
 
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index 15951321..ed118bbb 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -415,13 +415,26 @@ def run_stream():
 
     def get_stats(self) -> dict[str, Any]:
         """Get engine statistics."""
-        return {
+        stats = {
             "engine_type": "simple",
             "model_name": self._model_name,
             "is_mllm": self._is_mllm,
             "loaded": self._loaded,
         }
 
+        # Include Metal memory stats
+        try:
+            import mlx.core as mx
+
+            if mx.metal.is_available():
+                stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2)
+                stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2)
+                stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2)
+        except Exception:
+            pass
+
+        return stats
+
     def get_cache_stats(self) -> dict[str, Any] | None:
         """Get cache statistics (for MLLM models)."""
         if self._is_mllm and self._model is not None:
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 764f0543..a987f706 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -26,15 +26,16 @@
 from dataclasses import dataclass, field
 from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple
 
+import mlx.core as mx
 from .mllm_batch_generator import (
     MLLMBatchGenerator,
     MLLMBatchRequest,
     MLLMBatchResponse,
 )
+from .mllm_cache import MLLMCacheManager
 from .multimodal_processor import MultimodalProcessor
 from .request import RequestOutput, RequestStatus, SamplingParams
-from .mllm_cache import MLLMCacheManager
 
 logger = logging.getLogger(__name__)
 
 
@@ -753,6 +754,15 @@ def get_stats(self) -> Dict[str, Any]:
         if self.vision_cache:
             stats["vision_cache"] = self.vision_cache.get_stats()
 
+        # Include Metal memory stats
+        try:
+            if mx.metal.is_available():
+                stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2)
+                stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2)
+                stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2)
+        except Exception:
+            pass
+
         return stats
 
     def reset(self) -> None: