diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index e96317ef..da3ccfc1 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -369,7 +369,7 @@ async def stream_generate( ): prompt_tokens = ( chunk.prompt_tokens - if hasattr(chunk, "prompt_tokens") + if hasattr(chunk, "prompt_tokens") and chunk.prompt_tokens else prompt_tokens ) completion_tokens += 1 @@ -472,9 +472,20 @@ async def chat( **kwargs, ) text = clean_output_text(output.text) + # Count prompt tokens from the full templated prompt + tokenizer = self._model.tokenizer + template_kwargs = { + "tokenize": True, + "add_generation_prompt": True, + } + if template_tools: + template_kwargs["tools"] = template_tools + prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs) + prompt_token_count = len(prompt_ids) return GenerationOutput( text=text, tokens=output.tokens, + prompt_tokens=prompt_token_count, completion_tokens=len(output.tokens), finish_reason=output.finish_reason, ) diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 72182037..75bbab85 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -30,6 +30,7 @@ class StreamingOutput: token: int finished: bool = False finish_reason: str | None = None + prompt_tokens: int = 0 class MLXLanguageModel: @@ -203,6 +204,9 @@ def stream_generate( # Create sampler with parameters sampler = self._create_sampler(temperature, top_p) + # Count prompt tokens once upfront + num_prompt_tokens = len(self.tokenizer.encode(prompt)) + token_count = 0 accumulated_text = "" @@ -241,6 +245,7 @@ def stream_generate( token=response.token if hasattr(response, "token") else 0, finished=finished, finish_reason=finish_reason, + prompt_tokens=num_prompt_tokens, ) if finished: