diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py
index 689cf8a4f..ede597be1 100644
--- a/vllm_mlx/api/models.py
+++ b/vllm_mlx/api/models.py
@@ -336,3 +336,4 @@ class ChatCompletionChunk(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionChunkChoice]
+    usage: Usage | None = None
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index b4a72c3ab..5c8c75c01 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -170,7 +170,9 @@ async def stream_generate(
         await self.start()
 
         accumulated_text = ""
-        token_count = 0
+        prompt_tokens = 0
+        completion_tokens = 0
+        finished = False
 
         for chunk in self._model.stream_generate(
             prompt=prompt,
@@ -180,11 +182,12 @@
             stop=stop,
             **kwargs,
         ):
-            token_count += 1
+            prompt_tokens = chunk.prompt_tokens if hasattr(chunk, 'prompt_tokens') else prompt_tokens
+            completion_tokens += 1
             new_text = chunk.text if hasattr(chunk, 'text') else str(chunk)
             accumulated_text += new_text
 
-            finished = getattr(chunk, 'finished', False) or token_count >= max_tokens
+            finished = getattr(chunk, 'finished', False) or completion_tokens >= max_tokens
             finish_reason = None
             if finished:
                 finish_reason = getattr(chunk, 'finish_reason', 'stop')
@@ -192,7 +195,8 @@
             yield GenerationOutput(
                 text=accumulated_text,
                 new_text=new_text,
-                completion_tokens=token_count,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
                 finished=finished,
                 finish_reason=finish_reason,
             )
@@ -200,6 +204,18 @@
             if finished:
                 break
 
+        if not finished:
+            if prompt_tokens == 0:
+                prompt_tokens = len(self._model.tokenizer.encode(prompt))
+            yield GenerationOutput(
+                text=accumulated_text,
+                new_text="",
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                finished=True,
+                finish_reason=None,
+            )
+
     async def chat(
         self,
         messages: List[Dict[str, Any]],
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index efb8f2203..b4e7e8f84 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -49,6 +49,7 @@
 from typing import AsyncIterator, Optional
 
 from fastapi import FastAPI, HTTPException, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 
 # Import from new modular API
@@ -64,7 +65,7 @@
                         parse_tool_calls)
 from .api.utils import (clean_output_text, extract_multimodal_content,
                         is_mllm_model)
-from .engine import BaseEngine, BatchedEngine, SimpleEngine
+from .engine import BaseEngine, BatchedEngine, SimpleEngine, GenerationOutput
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -264,6 +265,17 @@ def load_model(
     logger.info(f"Default max tokens: {_default_max_tokens}")
 
 
+def get_usage(output: GenerationOutput) -> Usage:
+    """Extract usage metrics from GenerationOutput."""
+    total_prompt_tokens = output.prompt_tokens if hasattr(output, 'prompt_tokens') else 0
+    total_completion_tokens = output.completion_tokens if hasattr(output, 'completion_tokens') else 0
+    return Usage(
+        prompt_tokens=total_prompt_tokens,
+        completion_tokens=total_completion_tokens,
+        total_tokens=total_prompt_tokens + total_completion_tokens,
+    )
+
+
 @app.get("/health")
 async def health():
     """Health check endpoint."""
@@ -542,6 +554,7 @@ async def create_completion(request: CompletionRequest):
     timeout = request.timeout or _default_timeout
     choices = []
     total_completion_tokens = 0
+    total_prompt_tokens = 0
 
     for i, prompt in enumerate(prompts):
         try:
@@ -568,19 +581,21 @@ async def create_completion(request: CompletionRequest):
                 )
             )
             total_completion_tokens += output.completion_tokens
+            total_prompt_tokens += output.prompt_tokens if hasattr(output, 'prompt_tokens') else 0
 
     elapsed = time.perf_counter() - start_time
     tokens_per_sec = total_completion_tokens / elapsed if elapsed > 0 else 0
     logger.info(
-        f"Completion: {total_completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+        f"Completion: {total_prompt_tokens} prompt + {total_completion_tokens} completion tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
     )
 
     return CompletionResponse(
         model=request.model,
         choices=choices,
         usage=Usage(
+            prompt_tokens=total_prompt_tokens,
             completion_tokens=total_completion_tokens,
-            total_tokens=total_completion_tokens,
+            total_tokens=total_prompt_tokens + total_completion_tokens,
         ),
     )
 
@@ -790,6 +805,8 @@ async def stream_completion(
                 }
             ],
         }
+        if output.finished:
+            data["usage"] = get_usage(output).model_dump()
         yield f"data: {json.dumps(data)}\n\n"
 
     yield "data: [DONE]\n\n"
@@ -841,6 +858,7 @@ async def stream_chat_completion(
                     finish_reason=output.finish_reason if output.finished else None,
                 )
             ],
+            usage=get_usage(output) if output.finished else None,
         )
         yield f"data: {chunk.model_dump_json()}\n\n"
 