Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,16 @@ def get_stats(self) -> dict[str, Any]:
}

if self._mllm_scheduler:
stats["mllm_scheduler"] = self._mllm_scheduler.get_stats()
mllm_stats = self._mllm_scheduler.get_stats()
stats["mllm_scheduler"] = mllm_stats
# Promote Metal memory stats to top-level for /v1/status
for key in (
"metal_active_memory_gb",
"metal_peak_memory_gb",
"metal_cache_memory_gb",
):
if key in mllm_stats:
stats[key] = mllm_stats[key]
elif self._engine:
stats.update(self._engine.get_stats())

Expand Down
15 changes: 14 additions & 1 deletion vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,26 @@ def run_stream():

def get_stats(self) -> dict[str, Any]:
    """Return a snapshot of engine statistics.

    Always reports the engine type, model name, MLLM flag, and load
    state.  When MLX is importable and a Metal device is available, the
    snapshot additionally carries active/peak/cache Metal memory usage
    in gigabytes (rounded to two decimals); otherwise those keys are
    simply absent.
    """
    snapshot: dict[str, Any] = {
        "engine_type": "simple",
        "model_name": self._model_name,
        "is_mllm": self._is_mllm,
        "loaded": self._loaded,
    }

    # Best-effort Metal telemetry: MLX may not be installed or the
    # Metal backend may be unavailable — in either case the memory
    # keys are omitted rather than raising.
    try:
        import mlx.core as mx

        if mx.metal.is_available():
            for key, read_bytes in (
                ("metal_active_memory_gb", mx.get_active_memory),
                ("metal_peak_memory_gb", mx.get_peak_memory),
                ("metal_cache_memory_gb", mx.get_cache_memory),
            ):
                snapshot[key] = round(read_bytes() / 1e9, 2)
    except Exception:
        pass

    return snapshot

def get_cache_stats(self) -> dict[str, Any] | None:
"""Get cache statistics (for MLLM models)."""
if self._is_mllm and self._model is not None:
Expand Down
12 changes: 11 additions & 1 deletion vllm_mlx/mllm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple

import mlx.core as mx

from .mllm_batch_generator import (
MLLMBatchGenerator,
MLLMBatchRequest,
MLLMBatchResponse,
)
from .mllm_cache import MLLMCacheManager
from .multimodal_processor import MultimodalProcessor
from .request import RequestOutput, RequestStatus, SamplingParams
from .mllm_cache import MLLMCacheManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -753,6 +754,15 @@ def get_stats(self) -> Dict[str, Any]:
if self.vision_cache:
stats["vision_cache"] = self.vision_cache.get_stats()

# Include Metal memory stats
try:
if mx.metal.is_available():
stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2)
stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2)
stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2)
except Exception:
pass

return stats

def reset(self) -> None:
Expand Down
Loading