Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
656 changes: 656 additions & 0 deletions tests/v1/e2e/test_streaming_input.py

Large diffs are not rendered by default.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -93,30 +93,6 @@ async def mock_add_request(*args, **kwargs):
assert outputs[1].finished is True


@pytest.fixture
def mock_async_llm_streaming():
    """Build a mocked AsyncLLM whose generate path runs the real code.

    Everything is a MagicMock except ``generate`` and
    ``_add_streaming_request``, which are bound from the real AsyncLLM class
    so the streaming-input logic under test actually executes.
    """
    mock_llm = MagicMock(spec=AsyncLLM)

    # Minimal config surface that AsyncLLM.generate reads.
    mock_llm.vllm_config = MagicMock()
    mock_llm.vllm_config.cache_config.kv_sharing_fast_prefill = False
    mock_llm.model_config = MagicMock()
    mock_llm.model_config.max_model_len = 2048

    # Flags and pause/error state consulted by the real methods.
    mock_llm.log_requests = False
    mock_llm.errored = False
    mock_llm._pause_cond = asyncio.Condition()
    mock_llm._paused = False

    # Stub out the background output handler and the abort path.
    mock_llm._run_output_handler = MagicMock()
    mock_llm.abort = AsyncMock()

    # Bind the genuine implementations onto the mock instance.
    mock_llm.generate = AsyncLLM.generate.__get__(mock_llm, AsyncLLM)
    mock_llm._add_streaming_request = AsyncLLM._add_streaming_request.__get__(
        mock_llm, AsyncLLM
    )

    return mock_llm


def make_output(request_id: str, finished: bool) -> RequestOutput:
"""Helper to create a RequestOutput."""
return RequestOutput(
Expand All @@ -130,42 +106,66 @@ def make_output(request_id: str, finished: bool) -> RequestOutput:


@pytest.mark.asyncio
async def test_generate_with_async_generator(mock_async_llm_streaming):
"""Test generate with an async input generator."""
async def test_generate_with_async_generator():
"""Test generate with an async input generator.

With the new streaming input API, completion is signaled by finishing
the input generator (not via a resumable flag). Each input chunk
produces intermediate outputs, and the final output has finished=True.
"""
request_id = "test"
sampling_params = SamplingParams(max_tokens=10)

segment_count = 0
shared_queue = RequestOutputCollector(RequestOutputKind.FINAL_ONLY, request_id)

async def mock_add_request(*args, **kwargs):
nonlocal segment_count
segment_count += 1
current_segment = segment_count
llm = MagicMock(spec=AsyncLLM)
llm.vllm_config = MagicMock()
llm.vllm_config.cache_config.kv_sharing_fast_prefill = False
llm.model_config = MagicMock()
llm.model_config.max_model_len = 2048
llm.log_requests = False
llm.errored = False
llm._pause_cond = asyncio.Condition()
llm._paused = False
llm._run_output_handler = MagicMock()
llm.abort = AsyncMock()

# Stagger outputs to prevent aggregation in RequestOutputCollector
async def produce_output():
await asyncio.sleep(current_segment * 0.05)
shared_queue.put(make_output(request_id, finished=True))
# Bind the real generate method
llm.generate = AsyncLLM.generate.__get__(llm, AsyncLLM)

asyncio.create_task(produce_output())
return shared_queue
# Track inputs processed
inputs_received = []
queue = RequestOutputCollector(RequestOutputKind.DELTA, request_id)

async def mock_add_request(req_id, prompt, params, *args, **kwargs):
# When prompt is an AsyncGenerator, process streaming inputs
if isinstance(prompt, AsyncGenerator):
# Process inputs in background, produce outputs
async def handle_stream():
async for input_chunk in prompt:
inputs_received.append(input_chunk.prompt)
# Each input produces an intermediate output
queue.put(make_output(req_id, finished=False))
await asyncio.sleep(0.01)
# Final output when stream ends
queue.put(make_output(req_id, finished=True))

asyncio.create_task(handle_stream())
return queue
return queue

mock_async_llm_streaming.add_request = mock_add_request
llm.add_request = mock_add_request

async def input_generator() -> AsyncGenerator[StreamingInput, None]:
yield StreamingInput(
prompt="Hello", sampling_params=sampling_params, resumable=True
)
yield StreamingInput(
prompt=" world", sampling_params=sampling_params, resumable=False
)
yield StreamingInput(prompt="Hello", sampling_params=sampling_params)
yield StreamingInput(prompt=" world", sampling_params=sampling_params)

outputs = []
async for output in mock_async_llm_streaming.generate(
input_generator(), None, request_id
):
async for output in llm.generate(input_generator(), sampling_params, request_id):
outputs.append(output)

assert len(outputs) == 2
assert segment_count == 2
# Two intermediate outputs + one final output
assert len(outputs) == 3
assert outputs[0].finished is False
assert outputs[1].finished is False
assert outputs[2].finished is True
# Both inputs were processed
assert inputs_received == ["Hello", " world"]
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def create_scheduler() -> Scheduler:
vllm_config.model_config.skip_tokenizer_init = True
vllm_config.model_config.is_multimodal_model = False
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.enable_return_routed_experts = False
vllm_config.cache_config = MagicMock()
vllm_config.cache_config.num_gpu_blocks = 1000
vllm_config.cache_config.enable_prefix_caching = False
Expand All @@ -63,7 +64,10 @@ def create_scheduler() -> Scheduler:
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(16, 1, 1, torch.float32, False)
["layer"],
FullAttentionSpec(
block_size=16, num_kv_heads=1, head_size=1, dtype=torch.float32
),
)
],
)
Expand Down Expand Up @@ -118,27 +122,30 @@ def test_update_request_as_session_max_token(self):
new_request.sampling_params = SamplingParams(max_tokens=10)
new_request.max_tokens = 10 # Additional max_tokens from new request

session.streaming_queue.append(StreamingUpdate.from_request(new_request))
scheduler._update_request_as_session(session)
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)

assert session.sampling_params.max_tokens == 10
assert session.max_tokens == 20 # 10 + 10
# _update_request_as_session clears output tokens first, so
# max_tokens = num_output_tokens (0) + update.max_tokens (10) = 10
assert session.max_tokens == 10

session.num_computed_tokens = len(session.prompt_token_ids)

# only generated additional 5
session._output_token_ids = [1] * 15
# Simulate generating 5 more output tokens
session._output_token_ids = [1] * 5
new_request2 = DummyRequest(
request_id="session",
prompt_token_ids=[7, 8, 9],
)
new_request2.sampling_params = SamplingParams(max_tokens=10)
new_request2.max_tokens = 10
session.streaming_queue.append(StreamingUpdate.from_request(new_request2))
scheduler._update_request_as_session(session)
update2 = StreamingUpdate.from_request(new_request2)
scheduler._update_request_as_session(session, update2)

assert session.sampling_params.max_tokens == 10
assert session.max_tokens == 25 # 15 + 10
# Again, output tokens are cleared first, so max_tokens = 0 + 10 = 10
assert session.max_tokens == 10

def test_update_request_as_session(self):
scheduler = create_scheduler()
Expand All @@ -155,8 +162,8 @@ def test_update_request_as_session(self):
)
new_request.sampling_params = SamplingParams(max_tokens=10)

session.streaming_queue.append(StreamingUpdate.from_request(new_request))
scheduler._update_request_as_session(session)
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)

assert session.prompt_token_ids == [1, 2, 3, 4, 5, 6]
assert session._all_token_ids == [1, 2, 3, 4, 5, 6]
Expand Down Expand Up @@ -190,15 +197,22 @@ def test_update_request_as_session_with_multimodal(self):
prompt_token_ids=[4, 5, 6, 7],
mm_features=[mm_feature],
)
session.streaming_queue.append(StreamingUpdate.from_request(new_request))
scheduler._update_request_as_session(session)
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)

assert len(session.mm_features) == 2
assert session.mm_features[0].mm_position.offset == 1
# 2 + len([1, 2, 3])
assert session.mm_features[1].mm_position.offset == 5

def test_process_streaming_requests_with_finish_session(self):
"""Test that a non-resumable request signals stream completion.

With the new streaming API, completion is signaled by closing/finishing
the input generator. When a non-resumable request is added to a session
in WAITING_FOR_STREAMING_REQ state, the session is finished immediately
with FINISHED_ABORTED status.
"""
scheduler = create_scheduler()

session = DummyRequest(
Expand All @@ -210,37 +224,33 @@ def test_process_streaming_requests_with_finish_session(self):
session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
session.num_computed_tokens = len(session.prompt_token_ids)

# A non-resumable request signals stream completion
close_request = DummyRequest(
request_id="session",
prompt_token_ids=[0],
resumable=False,
max_tokens=1,
)
scheduler.add_request(close_request)
assert close_request.status == RequestStatus.WAITING
assert len(session.streaming_queue) == 1

sout = scheduler.schedule()
mro = ModelRunnerOutput(
req_ids=[session.request_id],
req_id_to_index={session.request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={session.request_id: None},
pooler_output=None,
)
out = scheduler.update_from_output(sout, mro)
assert session.status == RequestStatus.FINISHED_LENGTH_CAPPED
assert len(out) == 1
assert out[0].outputs[0].request_id == session.request_id
assert out[0].outputs[0].resumable is False
# The session should be immediately finished (stream completed)
assert session.status == RequestStatus.FINISHED_ABORTED
# The session should be removed from the scheduler
assert session.request_id not in scheduler.requests

def test_streaming_request_session_update(self):
"""Test that a resumable request updates a waiting session directly.

When a session is in WAITING_FOR_STREAMING_REQ state and a new resumable
request arrives, the update is applied directly via _update_request_as_session,
not queued.
"""
scheduler = create_scheduler()

session = DummyRequest(
request_id="session",
prompt_token_ids=[1, 2, 3],
resumable=True,
)
scheduler.add_request(session)
session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
Expand All @@ -253,13 +263,16 @@ def test_streaming_request_session_update(self):
)

scheduler.add_request(next_request)
assert next_request.status == RequestStatus.WAITING
assert len(session.streaming_queue) == 1

# With the new behavior, when session is in WAITING_FOR_STREAMING_REQ,
# the update is applied directly (not queued), and session status
# becomes WAITING
assert session.status == RequestStatus.WAITING
assert session.prompt_token_ids == [1, 2, 3, 4, 5]

_ = scheduler.schedule()

assert session.status == RequestStatus.RUNNING
assert session.prompt_token_ids == [1, 2, 3, 4, 5]

def test_update_request_as_session_with_output_tokens(self):
scheduler = create_scheduler()
Expand All @@ -280,14 +293,19 @@ def test_update_request_as_session_with_output_tokens(self):
prompt_token_ids=[4, 5],
)

session.streaming_queue.append(StreamingUpdate.from_request(new_request))
scheduler._update_request_as_session(session)
update = StreamingUpdate.from_request(new_request)
scheduler._update_request_as_session(session, update)

# Verify the last output token (11) was removed, and new prompt tokens added
# _update_request_as_session keeps computed output tokens (they become
# part of the prompt) and only discards the final uncomputed sampled
# token. Computed output token 10 is kept, uncomputed token 11 is
# discarded.
assert session._all_token_ids == [1, 2, 3, 10, 4, 5]
assert session.prompt_token_ids == [1, 2, 3, 4, 5]
# Verify output tokens list is unchanged (only removed from _all_token_ids)
assert session._output_token_ids == [10, 11]
assert session.prompt_token_ids == [1, 2, 3, 10, 4, 5]
# Output tokens list is cleared
assert session._output_token_ids == []
# num_computed_tokens is unchanged (KV cache still valid for computed
# tokens)
assert session.num_computed_tokens == 4
# Verify that the next schedule will only process the new prompt tokens
# num_new_tokens = num_tokens - num_computed_tokens = 6 - 4 = 2
Expand Down Expand Up @@ -369,9 +387,13 @@ def test_streaming_e2e_lifecycle(self):

# Step 3: Simulate model runner caching the prompt_token_ids
# This simulates gpu_model_runner.py:706-720 CachedRequestState creation
# The model runner makes a copy of prompt_token_ids when creating
# CachedRequestState
cached_state_cycle1 = {
"req_id": session.request_id,
"prompt_token_ids": new_req_data_cycle1.prompt_token_ids, # Must be a copy!
"prompt_token_ids": list(
new_req_data_cycle1.prompt_token_ids
), # Explicit copy
"output_token_ids": [],
"num_computed_tokens": 0,
}
Expand Down Expand Up @@ -495,31 +517,39 @@ def test_streaming_e2e_lifecycle(self):
prompt_token_ids=[4, 5],
)
scheduler.add_request(new_request)
assert new_request.status == RequestStatus.WAITING
assert len(session.streaming_queue) == 1

# Step 13: Scheduler merges new request into session and schedules
# With the new streaming API, when session is in WAITING_FOR_STREAMING_REQ,
# the update is applied directly via _update_request_as_session (not queued).
# The session status becomes WAITING after the update is applied.
assert session.status == RequestStatus.WAITING

# Step 13: Scheduler schedules the updated session
scheduler_output_cycle3 = scheduler.schedule()

# Verify scheduler created NewRequestData with merged _all_token_ids
# Verify scheduler created NewRequestData with merged prompt_token_ids
assert len(scheduler_output_cycle3.scheduled_new_reqs) == 1
assert (
scheduler_output_cycle3.scheduled_new_reqs[0].prompt_token_ids
== session._all_token_ids
== session.prompt_token_ids
)
assert (
scheduler_output_cycle3.num_scheduled_tokens[session.request_id] == 2
) # Only new tokens [4, 5]
# STOP_TOKEN removed from _all_token_ids
# Computed output tokens are kept (become part of prompt), only the
# final uncomputed sampled token (STOP_TOKEN) is discarded
assert session._all_token_ids == [1, 2, 3, 10, 4, 5]
assert session.prompt_token_ids == [1, 2, 3, 4, 5] # Only prompts
assert session._output_token_ids == [10, STOP_TOKEN]
assert session.prompt_token_ids == [1, 2, 3, 10, 4, 5] # Includes kept output
assert session._output_token_ids == [] # Output tokens are cleared

# Step 14: Model runner caches NEW prompt_token_ids reference
# The model runner makes a copy of prompt_token_ids when creating
# CachedRequestState
new_req_data_cycle3 = scheduler_output_cycle3.scheduled_new_reqs[0]
cached_state_cycle3 = {
"req_id": session.request_id,
"prompt_token_ids": new_req_data_cycle3.prompt_token_ids,
"prompt_token_ids": list(
new_req_data_cycle3.prompt_token_ids
), # Explicit copy
"output_token_ids": [],
"num_computed_tokens": session.num_computed_tokens,
}
Expand Down
10 changes: 10 additions & 0 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ def __repr__(self) -> str:
)


# Sentinel RequestOutput signaling that a streaming-input request is
# finished. It carries no payload (empty request_id, no prompt, no outputs)
# and only sets finished=True; consumers presumably compare against it by
# identity (``out is STREAM_FINISHED``) — verify against callers.
STREAM_FINISHED = RequestOutput(
    request_id="",
    prompt=None,
    prompt_token_ids=None,
    prompt_logprobs=None,
    outputs=[],
    finished=True,
)

_O = TypeVar("_O", default=PoolingOutput)


Expand Down
Loading