diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index 3ac52b4b0..4d1059243 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -768,14 +768,27 @@ def get_stats(self) -> dict[str, Any]:
         if self._mllm_scheduler:
             mllm_stats = self._mllm_scheduler.get_stats()
             stats["mllm_scheduler"] = mllm_stats
-            # Promote Metal memory stats to top-level for /v1/status
+            # Promote stats to top-level for /v1/status and monitoring
             for key in (
+                "running",
+                "num_running",
+                "num_waiting",
+                "num_requests_processed",
+                "total_prompt_tokens",
+                "total_completion_tokens",
                 "metal_active_memory_gb",
                 "metal_peak_memory_gb",
                 "metal_cache_memory_gb",
+                "memory_aware_cache",
+                "paged_cache",
+                "prefix_cache",
+                "requests",
             ):
                 if key in mllm_stats:
                     stats[key] = mllm_stats[key]
+            # MLLM engine is always "running" once loaded
+            if "running" not in stats:
+                stats["running"] = self._loaded
         elif self._engine:
             stats.update(self._engine.get_stats())
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 555b230f2..027fed11e 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -94,6 +94,9 @@ class MLLMRequest:
     num_prompt_tokens: int = 0
     num_output_tokens: int = 0

+    # Timing
+    first_token_time: Optional[float] = None
+

 @dataclass
 class MLLMSchedulerOutput:
@@ -457,6 +460,9 @@ def _process_batch_responses(
             request.output_tokens.append(response.token)
             request.num_output_tokens = len(request.output_tokens)

+            if request.first_token_time is None and request.num_output_tokens > 0:
+                request.first_token_time = time.time()
+
             # Decode the new token using streaming detokenizer (UTF-8 safe).
             # Skip stop tokens — they are not content.
             if response.finish_reason == "stop":
@@ -778,6 +784,70 @@ async def generate(

     # ========== Stats and utilities ==========

+    def get_running_requests_info(self) -> List[Dict[str, Any]]:
+        """Per-request details for status endpoint."""
+        now = time.time()
+        result = []
+
+        # Waiting requests
+        for req in self.waiting:
+            result.append(
+                {
+                    "request_id": req.request_id,
+                    "status": "waiting",
+                    "phase": "queued",
+                    "elapsed_s": round(now - req.arrival_time, 2),
+                    "prompt_tokens": req.num_prompt_tokens,
+                    "completion_tokens": 0,
+                    "max_tokens": req.sampling_params.max_tokens,
+                    "progress": 0.0,
+                    "tokens_per_second": None,
+                    "ttft_s": None,
+                    "cache_hit_type": None,
+                    "cached_tokens": 0,
+                }
+            )
+
+        # Running requests
+        for req in self.running.values():
+            n_out = req.num_output_tokens
+            elapsed = now - req.arrival_time
+
+            if n_out == 0:
+                phase = "prefill"
+            else:
+                phase = "generation"
+
+            tok_s = None
+            ttft = None
+            if req.first_token_time is not None:
+                ttft = round(req.first_token_time - req.arrival_time, 3)
+                gen_elapsed = now - req.first_token_time
+                if gen_elapsed > 0 and n_out > 0:
+                    tok_s = round(n_out / gen_elapsed, 1)
+
+            max_tokens = req.sampling_params.max_tokens
+            progress = round(n_out / max_tokens, 3) if max_tokens > 0 else 0.0
+
+            result.append(
+                {
+                    "request_id": req.request_id,
+                    "status": "running",
+                    "phase": phase,
+                    "elapsed_s": round(elapsed, 2),
+                    "prompt_tokens": req.num_prompt_tokens,
+                    "completion_tokens": n_out,
+                    "max_tokens": max_tokens,
+                    "progress": min(progress, 1.0),
+                    "tokens_per_second": tok_s,
+                    "ttft_s": ttft,
+                    "cache_hit_type": None,
+                    "cached_tokens": 0,
+                }
+            )
+
+        return result
+
     def get_stats(self) -> Dict[str, Any]:
         """Get scheduler statistics."""
         stats = {
@@ -787,6 +857,7 @@ def get_stats(self) -> Dict[str, Any]:
             "num_requests_processed": self.num_requests_processed,
"total_prompt_tokens": self.total_prompt_tokens, "total_completion_tokens": self.total_completion_tokens, + "requests": self.get_running_requests_info(), } if self.batch_generator is not None: