1 change: 1 addition & 0 deletions vllm_mlx/api/models.py
@@ -336,3 +336,4 @@ class ChatCompletionChunk(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionChunkChoice]
+usage: Usage | None = None
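For context on how the new field is meant to be used: intermediate stream chunks leave usage unset, and only the final chunk carries token totals. The sketch below is illustrative rather than code from this PR; it trims both models down to the fields visible in this diff (choices is simplified, the real classes have more fields), and the model name and counts are invented.

    # Minimal sketch with assumed, trimmed-down models (not the project's full definitions).
    import time
    from pydantic import BaseModel, Field

    class Usage(BaseModel):
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int

    class ChatCompletionChunk(BaseModel):
        created: int = Field(default_factory=lambda: int(time.time()))
        model: str
        choices: list = []          # simplified; the real field is List[ChatCompletionChunkChoice]
        usage: Usage | None = None  # new: stays None until the final chunk of a stream

    # Only the last chunk before [DONE] carries totals.
    final_chunk = ChatCompletionChunk(
        model="example-model",      # placeholder name
        usage=Usage(prompt_tokens=42, completion_tokens=128, total_tokens=170),
    )
    print(final_chunk.model_dump_json())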
24 changes: 20 additions & 4 deletions vllm_mlx/engine/simple.py
@@ -170,7 +170,9 @@ async def stream_generate(
await self.start()

accumulated_text = ""
-token_count = 0
+prompt_tokens = 0
+completion_tokens = 0
+finished = False

for chunk in self._model.stream_generate(
prompt=prompt,
@@ -180,26 +182,40 @@
stop=stop,
**kwargs,
):
-token_count += 1
+prompt_tokens = chunk.prompt_tokens if hasattr(chunk, 'prompt_tokens') else prompt_tokens
+completion_tokens += 1
new_text = chunk.text if hasattr(chunk, 'text') else str(chunk)
accumulated_text += new_text

-finished = getattr(chunk, 'finished', False) or token_count >= max_tokens
+finished = getattr(chunk, 'finished', False) or completion_tokens >= max_tokens
finish_reason = None
if finished:
finish_reason = getattr(chunk, 'finish_reason', 'stop')

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
-completion_tokens=token_count,
+prompt_tokens=prompt_tokens,
+completion_tokens=completion_tokens,
finished=finished,
finish_reason=finish_reason,
)

if finished:
break

+if not finished:
+    if prompt_tokens == 0:
+        prompt_tokens = len(self._model.tokenizer.encode(prompt))
+    yield GenerationOutput(
+        text=accumulated_text,
+        new_text="",
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        finished=True,
+        finish_reason=None,
+    )

async def chat(
self,
messages: List[Dict[str, Any]],
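As a usage illustration of the engine change above (a sketch, not code from this PR): callers can read prompt and completion token counts from the stream, and the new trailing yield guarantees a finished output with counts even when the underlying model stream ends without signalling completion. The engine setup is assumed and elided.

    # Sketch: consume stream_generate and collect the new token counts.
    # `engine` is assumed to be a SimpleEngine instance with a loaded model.
    async def generate_with_usage(engine, prompt: str, max_tokens: int = 256):
        text = ""
        prompt_toks = completion_toks = 0
        async for out in engine.stream_generate(prompt=prompt, max_tokens=max_tokens):
            text = out.text
            if out.finished:
                prompt_toks = out.prompt_tokens
                completion_toks = out.completion_tokens
        return text, prompt_toks, completion_toks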
24 changes: 21 additions & 3 deletions vllm_mlx/server.py
@@ -49,6 +49,7 @@
from typing import AsyncIterator, Optional

from fastapi import FastAPI, HTTPException, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

# Import from new modular API
@@ -64,7 +65,7 @@
parse_tool_calls)
from .api.utils import (clean_output_text, extract_multimodal_content,
is_mllm_model)
-from .engine import BaseEngine, BatchedEngine, SimpleEngine
+from .engine import BaseEngine, BatchedEngine, SimpleEngine, GenerationOutput

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -264,6 +265,17 @@ def load_model(
logger.info(f"Default max tokens: {_default_max_tokens}")


+def get_usage(output: GenerationOutput) -> Usage:
+    """Extract usage metrics from GenerationOutput."""
+    total_prompt_tokens = output.prompt_tokens if hasattr(output, 'prompt_tokens') else 0
+    total_completion_tokens = output.completion_tokens if hasattr(output, 'completion_tokens') else 0
+    return Usage(
+        prompt_tokens=total_prompt_tokens,
+        completion_tokens=total_completion_tokens,
+        total_tokens=total_prompt_tokens + total_completion_tokens,
+    )
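A brief usage note (illustrative): the helper is called with the final GenerationOutput of a request and produces the OpenAI-style usage block; total_tokens is derived from the two counters rather than read off the output.

    # `final_output` is a hypothetical GenerationOutput with finished=True.
    usage = get_usage(final_output)
    assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens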


@app.get("/health")
async def health():
"""Health check endpoint."""
@@ -542,6 +554,7 @@ async def create_completion(request: CompletionRequest):
timeout = request.timeout or _default_timeout
choices = []
total_completion_tokens = 0
+total_prompt_tokens = 0

for i, prompt in enumerate(prompts):
try:
@@ -568,19 +581,21 @@
)
)
total_completion_tokens += output.completion_tokens
+total_prompt_tokens += output.prompt_tokens if hasattr(output, 'prompt_tokens') else 0

elapsed = time.perf_counter() - start_time
tokens_per_sec = total_completion_tokens / elapsed if elapsed > 0 else 0
logger.info(
f"Completion: {total_completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
f"Completion: {total_prompt_tokens} prompt + {total_completion_tokens} completion tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
)

return CompletionResponse(
model=request.model,
choices=choices,
usage=Usage(
+prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
-total_tokens=total_completion_tokens,
+total_tokens=total_prompt_tokens + total_completion_tokens,
),
)
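To make the accounting fix concrete (numbers invented): before this change the non-streaming response reported total_tokens equal to completion_tokens alone; with prompt tokens aggregated across every prompt in the batch, the usage block now adds up.

    # Illustrative usage block after the change:
    usage = Usage(prompt_tokens=42, completion_tokens=128, total_tokens=42 + 128)
    assert usage.total_tokens == 170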

@@ -790,6 +805,8 @@ async def stream_completion(
}
],
}
+if output.finished:
+    data["usage"] = get_usage(output).model_dump()
yield f"data: {json.dumps(data)}\n\n"

yield "data: [DONE]\n\n"
@@ -841,6 +858,7 @@ async def stream_chat_completion(
finish_reason=output.finish_reason if output.finished else None,
)
],
+usage=get_usage(output) if output.finished else None,
)
yield f"data: {chunk.model_dump_json()}\n\n"
