diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index 7202f625f..a95196633 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -38,6 +38,27 @@ def generate_side_effect(**kwargs): return result model.generate = MagicMock(side_effect=generate_side_effect) + + # stream_generate tracks concurrency the same way so tests that + # exercise SimpleEngine.generate() (which is now an accumulator + # over stream_generate) see the same serialization behavior. + def stream_generate_side_effect(**kwargs): + model._concurrent_count += 1 + model._max_concurrent = max(model._max_concurrent, model._concurrent_count) + import time + + time.sleep(0.05) + model._concurrent_count -= 1 + chunk = MagicMock() + chunk.text = "test response" + chunk.tokens = [1, 2, 3] + chunk.finished = True + chunk.finish_reason = "stop" + chunk.prompt_tokens = 3 + chunk.completion_tokens = 3 + yield chunk + + model.stream_generate = MagicMock(side_effect=stream_generate_side_effect) return model @pytest.fixture @@ -213,3 +234,86 @@ async def test_requests_complete_in_order(self, mock_model): assert len(results) == 3 for result in results: assert result.text == "test response" + + @pytest.mark.asyncio + async def test_generate_accumulates_over_stream_generate(self): + """generate() should iterate stream_generate() and return the last + yielded GenerationOutput, forwarding per-request kwargs (including + SpecPrefill overrides) through so they reach _stream_generate_specprefill. + """ + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.engine.simple import SimpleEngine + + captured_kwargs = {} + + async def fake_stream_generate(**kwargs): + captured_kwargs.update(kwargs) + # First chunk: mid-generation + yield GenerationOutput( + text="partial", + new_text="partial", + tokens=[1, 2], + prompt_tokens=11, + completion_tokens=2, + finished=False, + finish_reason=None, + ) + # Final chunk: finished + yield GenerationOutput( + text="partial final", + new_text=" final", + tokens=[1, 2, 3], + prompt_tokens=11, + completion_tokens=3, + finished=True, + finish_reason="stop", + ) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + output = await engine.generate( + prompt="say hi", + max_tokens=16, + temperature=0.6, + top_p=0.95, + specprefill=True, + specprefill_keep_pct=0.2, + ) + + # Accumulator returns the last GenerationOutput's fields + assert output.text == "partial final" + assert output.tokens == [1, 2, 3] + assert output.prompt_tokens == 11 + assert output.completion_tokens == 3 + assert output.finish_reason == "stop" + assert output.finished is True + + # Per-request SpecPrefill overrides reach stream_generate + assert captured_kwargs.get("prompt") == "say hi" + assert captured_kwargs.get("max_tokens") == 16 + assert captured_kwargs.get("specprefill") is True + assert captured_kwargs.get("specprefill_keep_pct") == 0.2 + + @pytest.mark.asyncio + async def test_generate_empty_stream_returns_safe_default(self): + """If stream_generate yields nothing, generate() returns an empty + stop-reason GenerationOutput rather than raising. + """ + from vllm_mlx.engine.simple import SimpleEngine + + async def empty_stream_generate(**kwargs): + return + yield # unreachable; makes this a generator + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine.stream_generate = empty_stream_generate # type: ignore[method-assign] + + output = await engine.generate(prompt="anything", max_tokens=5) + + assert output.text == "" + assert output.finish_reason == "stop" diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 39cfa849d..924d28e45 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -256,13 +256,27 @@ async def generate( """ Generate a complete response (non-streaming). + Thin accumulator over stream_generate(). stream_generate() is the + only code path that consumes per-request SpecPrefill overrides + (`specprefill`, `specprefill_keep_pct`) and routes through + _stream_generate_specprefill() when engaged. The prior direct + self._model.generate() path silently dropped those overrides for + non-streaming /v1/completions callers, so extra_body.specprefill + was advertised by the server but had no effect on this route. + + By iterating stream_generate() and returning the last + GenerationOutput, non-streaming clients get the same SpecPrefill + engagement, accurate prompt_tokens reporting, and per-request + override support as streaming clients. + Args: prompt: Input text max_tokens: Maximum tokens to generate temperature: Sampling temperature top_p: Top-p sampling stop: Stop sequences - **kwargs: Additional model-specific parameters + **kwargs: Additional parameters forwarded to stream_generate, + including per-request `specprefill` / `specprefill_keep_pct` Returns: GenerationOutput with complete text @@ -270,27 +284,28 @@ async def generate( if not self._loaded: await self.start() - output = await self._run_blocking_serialized( - self._model.generate, + last_output: GenerationOutput | None = None + async for output in self.stream_generate( prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stop=stop, **kwargs, - ) + ): + last_output = output - # Clean output text - text = clean_output_text(output.text) + if last_output is None: + return GenerationOutput(text="", finish_reason="stop") + text = clean_output_text(last_output.text) return GenerationOutput( text=text, - tokens=getattr(output, "tokens", []), - prompt_tokens=getattr(output, "prompt_tokens", 0), - completion_tokens=getattr( - output, "completion_tokens", len(getattr(output, "tokens", [])) - ), - finish_reason=output.finish_reason, + tokens=list(last_output.tokens), + prompt_tokens=last_output.prompt_tokens, + completion_tokens=last_output.completion_tokens, + finish_reason=last_output.finish_reason, + finished=True, ) async def stream_generate(