15 changes: 14 additions & 1 deletion vllm_mlx/engine/batched.py
@@ -768,14 +768,27 @@ def get_stats(self) -> dict[str, Any]:
         if self._mllm_scheduler:
             mllm_stats = self._mllm_scheduler.get_stats()
             stats["mllm_scheduler"] = mllm_stats
-            # Promote Metal memory stats to top-level for /v1/status
+            # Promote stats to top-level for /v1/status and monitoring
             for key in (
+                "running",
+                "num_running",
+                "num_waiting",
+                "num_requests_processed",
+                "total_prompt_tokens",
+                "total_completion_tokens",
                 "metal_active_memory_gb",
                 "metal_peak_memory_gb",
                 "metal_cache_memory_gb",
+                "memory_aware_cache",
+                "paged_cache",
+                "prefix_cache",
+                "requests",
             ):
                 if key in mllm_stats:
                     stats[key] = mllm_stats[key]
+            # MLLM engine is always "running" once loaded
+            if "running" not in stats:
+                stats["running"] = self._loaded
         elif self._engine:
             stats.update(self._engine.get_stats())

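With this change, a status endpoint no longer needs to reach into the nested stats["mllm_scheduler"] dict; the scheduler counters sit next to the Metal memory stats that were already promoted. A minimal sketch of a consumer, assuming an engine object exposing get_stats() as above (the helper name is illustrative, not part of this PR):

    def build_status_payload(engine) -> dict:
        stats = engine.get_stats()
        # Counters promoted by this PR are available at the top level.
        return {
            "running": stats.get("running", False),
            "num_running": stats.get("num_running", 0),
            "num_waiting": stats.get("num_waiting", 0),
            "requests": stats.get("requests", []),
        }
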
71 changes: 71 additions & 0 deletions vllm_mlx/mllm_scheduler.py
@@ -94,6 +94,9 @@ class MLLMRequest:
     num_prompt_tokens: int = 0
     num_output_tokens: int = 0

+    # Timing
+    first_token_time: Optional[float] = None
+

 @dataclass
 class MLLMSchedulerOutput:
@@ -457,6 +460,9 @@ def _process_batch_responses(
             request.output_tokens.append(response.token)
             request.num_output_tokens = len(request.output_tokens)

+            if request.first_token_time is None and request.num_output_tokens > 0:
+                request.first_token_time = time.time()
+
             # Decode the new token using streaming detokenizer (UTF-8 safe).
             # Skip stop tokens — they are not content.
             if response.finish_reason == "stop":
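The timestamp captured here drives the per-request metrics exposed below. Roughly, with all times in seconds from time.time() and now sampled at query time:

    ttft_s = first_token_time - arrival_time                           # time to first token
    tokens_per_second = num_output_tokens / (now - first_token_time)   # decode throughput, excludes prefill

Because the guard only fires while first_token_time is None, only the first decoded token sets it; later tokens leave the timestamp untouched.
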
@@ -778,6 +784,70 @@ async def generate(

     # ========== Stats and utilities ==========

+    def get_running_requests_info(self) -> List[Dict[str, Any]]:
+        """Per-request details for status endpoint."""
+        now = time.time()
+        result = []
+
+        # Waiting requests
+        for req in self.waiting:
+            result.append(
+                {
+                    "request_id": req.request_id,
+                    "status": "waiting",
+                    "phase": "queued",
+                    "elapsed_s": round(now - req.arrival_time, 2),
+                    "prompt_tokens": req.num_prompt_tokens,
+                    "completion_tokens": 0,
+                    "max_tokens": req.sampling_params.max_tokens,
+                    "progress": 0.0,
+                    "tokens_per_second": None,
+                    "ttft_s": None,
+                    "cache_hit_type": None,
+                    "cached_tokens": 0,
+                }
+            )
+
+        # Running requests
+        for req in self.running.values():
+            n_out = req.num_output_tokens
+            elapsed = now - req.arrival_time
+
+            if n_out == 0:
+                phase = "prefill"
+            else:
+                phase = "generation"
+
+            tok_s = None
+            ttft = None
+            if req.first_token_time is not None:
+                ttft = round(req.first_token_time - req.arrival_time, 3)
+                gen_elapsed = now - req.first_token_time
+                if gen_elapsed > 0 and n_out > 0:
+                    tok_s = round(n_out / gen_elapsed, 1)
+
+            max_tokens = req.sampling_params.max_tokens
+            progress = round(n_out / max_tokens, 3) if max_tokens > 0 else 0.0
+
+            result.append(
+                {
+                    "request_id": req.request_id,
+                    "status": "running",
+                    "phase": phase,
+                    "elapsed_s": round(elapsed, 2),
+                    "prompt_tokens": req.num_prompt_tokens,
+                    "completion_tokens": n_out,
+                    "max_tokens": max_tokens,
+                    "progress": min(progress, 1.0),
+                    "tokens_per_second": tok_s,
+                    "ttft_s": ttft,
+                    "cache_hit_type": None,
+                    "cached_tokens": 0,
+                }
+            )
+
+        return result
+
     def get_stats(self) -> Dict[str, Any]:
         """Get scheduler statistics."""
         stats = {
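For reference, a single entry in the list returned by get_running_requests_info() might look like this for an in-flight request (all values illustrative):

    {
        "request_id": "req-42",     # hypothetical id
        "status": "running",
        "phase": "generation",      # "prefill" until the first token lands
        "elapsed_s": 3.21,
        "prompt_tokens": 512,
        "completion_tokens": 128,
        "max_tokens": 1024,
        "progress": 0.125,          # completion_tokens / max_tokens, capped at 1.0
        "tokens_per_second": 48.7,
        "ttft_s": 0.412,
        "cache_hit_type": None,     # reserved: always None in this change
        "cached_tokens": 0,         # reserved: always 0 in this change
    }
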
@@ -787,6 +857,7 @@ def get_stats(self) -> Dict[str, Any]:
             "num_requests_processed": self.num_requests_processed,
             "total_prompt_tokens": self.total_prompt_tokens,
             "total_completion_tokens": self.total_completion_tokens,
+            "requests": self.get_running_requests_info(),
         }

         if self.batch_generator is not None:
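Taken together, the two changes let a dashboard or CLI poll live progress with no extra plumbing. A sketch, assuming a scheduler instance is already constructed and serving requests:

    import time

    while True:
        stats = scheduler.get_stats()
        for info in stats["requests"]:
            print(
                f"{info['request_id']}: {info['status']}/{info['phase']} "
                f"{info['completion_tokens']}/{info['max_tokens']} tokens, "
                f"ttft={info['ttft_s']}s, {info['tokens_per_second']} tok/s"
            )
        time.sleep(1.0)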