Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions tests/test_simple_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,27 @@ def generate_side_effect(**kwargs):
return result

model.generate = MagicMock(side_effect=generate_side_effect)

# stream_generate tracks concurrency the same way so tests that
# exercise SimpleEngine.generate() (which is now an accumulator
# over stream_generate) see the same serialization behavior.
def stream_generate_side_effect(**kwargs):
    """Generator side effect that yields one finished mock chunk.

    When iterated, it bumps ``model._concurrent_count``, folds the new value
    into ``model._max_concurrent``, sleeps briefly, decrements the counter
    again, and then yields a single chunk carrying fixed response fields.
    """
    from time import sleep

    model._concurrent_count += 1
    model._max_concurrent = max(model._max_concurrent, model._concurrent_count)
    # Keep the "in flight" window open long enough for any overlapping
    # caller to be reflected in _max_concurrent.
    sleep(0.05)
    model._concurrent_count -= 1

    # Mock keyword arguments become attributes on the created mock.
    yield MagicMock(
        text="test response",
        tokens=[1, 2, 3],
        finished=True,
        finish_reason="stop",
        prompt_tokens=3,
        completion_tokens=3,
    )

model.stream_generate = MagicMock(side_effect=stream_generate_side_effect)
return model

@pytest.fixture
Expand Down Expand Up @@ -213,3 +234,86 @@ async def test_requests_complete_in_order(self, mock_model):
assert len(results) == 3
for result in results:
assert result.text == "test response"

@pytest.mark.asyncio
async def test_generate_accumulates_over_stream_generate(self):
    """generate() should iterate stream_generate() and return the last
    yielded GenerationOutput, forwarding sampling parameters and per-request
    kwargs (including SpecPrefill overrides) through so they reach
    _stream_generate_specprefill.
    """
    from vllm_mlx.engine.base import GenerationOutput
    from vllm_mlx.engine.simple import SimpleEngine

    # Filled in by the fake below so the test can inspect what generate()
    # actually passed down.
    captured_kwargs = {}

    async def fake_stream_generate(**kwargs):
        captured_kwargs.update(kwargs)
        # First chunk: mid-generation
        yield GenerationOutput(
            text="partial",
            new_text="partial",
            tokens=[1, 2],
            prompt_tokens=11,
            completion_tokens=2,
            finished=False,
            finish_reason=None,
        )
        # Final chunk: finished
        yield GenerationOutput(
            text="partial final",
            new_text=" final",
            tokens=[1, 2, 3],
            prompt_tokens=11,
            completion_tokens=3,
            finished=True,
            finish_reason="stop",
        )

    with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
        engine = SimpleEngine("test-model")
        # Mark as loaded so generate() does not call start().
        engine._loaded = True
        engine.stream_generate = fake_stream_generate  # type: ignore[method-assign]

        output = await engine.generate(
            prompt="say hi",
            max_tokens=16,
            temperature=0.6,
            top_p=0.95,
            specprefill=True,
            specprefill_keep_pct=0.2,
        )

    # Accumulator returns the last GenerationOutput's fields
    assert output.text == "partial final"
    assert output.tokens == [1, 2, 3]
    assert output.prompt_tokens == 11
    assert output.completion_tokens == 3
    assert output.finish_reason == "stop"
    assert output.finished is True

    # Prompt and sampling parameters reach stream_generate unchanged
    assert captured_kwargs.get("prompt") == "say hi"
    assert captured_kwargs.get("max_tokens") == 16
    assert captured_kwargs.get("temperature") == 0.6
    assert captured_kwargs.get("top_p") == 0.95

    # Per-request SpecPrefill overrides reach stream_generate
    assert captured_kwargs.get("specprefill") is True
    assert captured_kwargs.get("specprefill_keep_pct") == 0.2

@pytest.mark.asyncio
async def test_generate_empty_stream_returns_safe_default(self):
    """generate() must not raise when stream_generate yields no chunks; it
    should hand back an empty-text GenerationOutput with a "stop" reason.
    """
    from vllm_mlx.engine.simple import SimpleEngine

    async def empty_stream_generate(**kwargs):
        # The never-taken branch exists only so Python compiles this as an
        # async generator; iterating it produces no items.
        if False:
            yield

    with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
        engine = SimpleEngine("test-model")
        engine._loaded = True
        engine.stream_generate = empty_stream_generate  # type: ignore[method-assign]

        result = await engine.generate(prompt="anything", max_tokens=5)

    assert (result.text, result.finish_reason) == ("", "stop")
39 changes: 27 additions & 12 deletions vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,41 +256,56 @@ async def generate(
"""
Generate a complete response (non-streaming).

Thin accumulator over stream_generate(). stream_generate() is the
only code path that consumes per-request SpecPrefill overrides
(`specprefill`, `specprefill_keep_pct`) and routes through
_stream_generate_specprefill() when engaged. The prior direct
self._model.generate() path silently dropped those overrides for
non-streaming /v1/completions callers, so extra_body.specprefill
was advertised by the server but had no effect on this route.

By iterating stream_generate() and returning the last
GenerationOutput, non-streaming clients get the same SpecPrefill
engagement, accurate prompt_tokens reporting, and per-request
override support as streaming clients.

Args:
prompt: Input text
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
stop: Stop sequences
**kwargs: Additional model-specific parameters
**kwargs: Additional parameters forwarded to stream_generate,
including per-request `specprefill` / `specprefill_keep_pct`

Returns:
GenerationOutput with complete text
"""
if not self._loaded:
await self.start()

output = await self._run_blocking_serialized(
self._model.generate,
last_output: GenerationOutput | None = None
async for output in self.stream_generate(
prompt=prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
**kwargs,
)
):
last_output = output

# Clean output text
text = clean_output_text(output.text)
if last_output is None:
return GenerationOutput(text="", finish_reason="stop")

text = clean_output_text(last_output.text)
return GenerationOutput(
text=text,
tokens=getattr(output, "tokens", []),
prompt_tokens=getattr(output, "prompt_tokens", 0),
completion_tokens=getattr(
output, "completion_tokens", len(getattr(output, "tokens", []))
),
finish_reason=output.finish_reason,
tokens=list(last_output.tokens),
prompt_tokens=last_output.prompt_tokens,
completion_tokens=last_output.completion_tokens,
finish_reason=last_output.finish_reason,
finished=True,
)

async def stream_generate(
Expand Down
Loading