diff --git a/tests/v1/e2e/test_streaming_input.py b/tests/v1/e2e/test_streaming_input.py new file mode 100644 index 000000000000..40bb30d9a95a --- /dev/null +++ b/tests/v1/e2e/test_streaming_input.py @@ -0,0 +1,656 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +End-to-end tests for the streaming input feature in AsyncLLM. + +These tests verify that: +1. Streaming inputs work correctly with bunched inputs (queued) +2. Streaming inputs work correctly with spaced out inputs +3. Outputs are equivalent whether inputs are bunched or spaced +4. Cancelling the output stream correctly aborts the session +5. Closing the input stream correctly signals completion +6. Queued inputs are cancelled when the session is aborted +""" + +import asyncio +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio + +from vllm import SamplingParams +from vllm.outputs import RequestOutput +from vllm.platforms import current_platform +from vllm.sampling_params import RequestOutputKind +from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput + +if not current_platform.is_cuda(): + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) + +# Use a small model that doesn't require authentication for fast tests +MODEL = "facebook/opt-125m" + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def engine(): + """Create an AsyncLLM engine for the test. + + Note: The fixture is module-scoped and runs on the module-level event loop + (loop_scope="module"), the same loop_scope the tests request, so one engine + instance is shared by every test in this file. 
+ """ + from vllm.engine.arg_utils import AsyncEngineArgs + + engine_args = AsyncEngineArgs( + model=MODEL, enforce_eager=True, gpu_memory_utilization=0.7 + ) + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args(engine_args) + try: + yield engine + finally: + engine.shutdown() + await asyncio.sleep(0.1) + + +def get_sampling_params(max_tokens: int = 20) -> SamplingParams: + """Create sampling params for streaming input tests.""" + return SamplingParams( + max_tokens=max_tokens, + ignore_eos=True, + output_kind=RequestOutputKind.DELTA, + temperature=0.0, # Deterministic for reproducibility + ) + + +async def collect_outputs( + output_gen: AsyncGenerator[RequestOutput, None], +) -> tuple[list[RequestOutput], str]: + """Collect all outputs from a generate call, return outputs and full text.""" + outputs: list[RequestOutput] = [] + full_text = "" + async for output in output_gen: + outputs.append(output) + if output.outputs and output.outputs[0].text: + full_text += output.outputs[0].text + return outputs, full_text + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_bunched(engine: AsyncLLM): + """Test streaming input where all inputs are sent at once (bunched/queued). + + This tests the case where multiple inputs arrive before any completes. + The inputs should be queued and processed in sequence. 
+ """ + request_id = "test_bunched" + sampling_params = get_sampling_params(max_tokens=10) + + # Create an input generator that yields all inputs quickly + async def bunched_input_generator() -> AsyncGenerator[StreamingInput, None]: + # Send multiple inputs rapidly - they should be queued + yield StreamingInput(prompt="Hello, my name is") + yield StreamingInput(prompt=" Alice and I like") + yield StreamingInput(prompt=" to code in Python") + + outputs, full_text = await collect_outputs( + engine.generate( + bunched_input_generator(), + sampling_params, + request_id, + ) + ) + + # Verify we got outputs + assert len(outputs) > 0, "Should have received outputs" + + # Verify the final output is marked as finished + assert outputs[-1].finished, "Last output should be marked as finished" + + # Verify intermediate outputs are not marked as finished + for output in outputs[:-1]: + assert not output.finished, "Intermediate outputs should not be finished" + + # Verify we generated some text + assert len(full_text) > 0, "Should have generated text" + print(f"Bunched test generated: {full_text}") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_spaced(engine: AsyncLLM): + """Test streaming input where inputs are spaced out. + + This tests the case where each input completes processing before the + next one is sent. Each chunk should be prefilled, generate tokens, + then the next chunk should be processed. 
+ """ + request_id = "test_spaced" + sampling_params = get_sampling_params(max_tokens=10) + + # Track when each input is sent + input_times: list[float] = [] + outputs_per_chunk: list[int] = [0, 0, 0] + current_chunk = 0 + + async def spaced_input_generator() -> AsyncGenerator[StreamingInput, None]: + nonlocal current_chunk + import time + + # First input + input_times.append(time.time()) + yield StreamingInput(prompt="Hello, my name is") + current_chunk = 0 + + # Wait for some outputs to be generated + await asyncio.sleep(0.5) + + # Second input + input_times.append(time.time()) + current_chunk = 1 + yield StreamingInput(prompt=" Alice and I like") + + # Wait for some outputs + await asyncio.sleep(0.5) + + # Third input + input_times.append(time.time()) + current_chunk = 2 + yield StreamingInput(prompt=" to code in Python") + + outputs: list[RequestOutput] = [] + full_text = "" + + async for output in engine.generate( + spaced_input_generator(), + sampling_params, + request_id, + ): + outputs.append(output) + if output.outputs and output.outputs[0].text: + full_text += output.outputs[0].text + outputs_per_chunk[current_chunk] += 1 + + # Verify we got outputs + assert len(outputs) > 0, "Should have received outputs" + + # Verify the final output is marked as finished + assert outputs[-1].finished, "Last output should be marked as finished" + + # Verify outputs were recorded for at least one chunk (with spaced inputs + # they may be spread across chunks, but that is not asserted here) + chunks_with_outputs = sum(1 for c in outputs_per_chunk if c > 0) + assert chunks_with_outputs >= 1, "Should have outputs from at least one chunk" + + print(f"Spaced test generated: {full_text}") + print(f"Outputs per chunk: {outputs_per_chunk}") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_output_equivalence(engine: AsyncLLM): + """Test that bunched and spaced inputs produce equivalent outputs. 
+ + When the same prompts are provided either bunched or spaced, + the final concatenated output should be the same (with deterministic + sampling). + """ + prompts = ["Hello, my name is", " Bob and I work", " at Anthropic"] + sampling_params = get_sampling_params(max_tokens=15) + + # Test bunched inputs + async def bunched_gen() -> AsyncGenerator[StreamingInput, None]: + for prompt in prompts: + yield StreamingInput(prompt=prompt) + + _, bunched_text = await collect_outputs( + engine.generate(bunched_gen(), sampling_params, "equiv_bunched") + ) + + # Test spaced inputs (same prompts, but with delays) + async def spaced_gen() -> AsyncGenerator[StreamingInput, None]: + for prompt in prompts: + yield StreamingInput(prompt=prompt) + await asyncio.sleep(0.3) + + _, spaced_text = await collect_outputs( + engine.generate(spaced_gen(), sampling_params, "equiv_spaced") + ) + + # Both should produce the same output since we use temperature=0 + assert bunched_text == spaced_text, ( + f"Bunched and spaced should produce same output.\n" + f"Bunched: {bunched_text!r}\n" + f"Spaced: {spaced_text!r}" + ) + + print(f"Equivalence test passed. Generated: {bunched_text}") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_cancel_output_stream(engine: AsyncLLM): + """Test that cancelling the output stream aborts the entire session. + + When the consumer cancels iteration over the output generator, + the session should be aborted including any queued inputs. 
+ """ + request_id = "test_cancel_output" + sampling_params = get_sampling_params(max_tokens=1000) + + input_completed = asyncio.Event() + input_task_cancelled = False + + async def slow_input_generator() -> AsyncGenerator[StreamingInput, None]: + nonlocal input_task_cancelled + try: + yield StreamingInput(prompt="Tell me a very long story about") + yield StreamingInput(prompt=" a dragon and a knight") + + # This should be cancelled before we get here + await asyncio.sleep(10) + yield StreamingInput(prompt=" who become friends") + input_completed.set() + except asyncio.CancelledError: + input_task_cancelled = True + raise + + outputs_received = 0 + output_gen = engine.generate(slow_input_generator(), sampling_params, request_id) + + # Collect a few outputs then cancel + try: + async for output in output_gen: + outputs_received += 1 + if outputs_received >= 5: + # Stop consuming; the finally block closes the generator explicitly + break + finally: + # Explicitly close the generator to ensure cleanup + await output_gen.aclose() + + # Give time for cleanup + await asyncio.sleep(0.5) + + # Verify we got some outputs before cancelling + assert outputs_received >= 5, "Should have received outputs before cancel" + + # Verify the input task was cancelled + assert input_task_cancelled, "Input task should have been cancelled" + + # Verify the session is properly cleaned up + assert not engine.output_processor.has_unfinished_requests(), ( + "Should have no unfinished requests after cancel" + ) + + print(f"Cancel test passed. Received {outputs_received} outputs before cancel") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_close_signals_completion(engine: AsyncLLM): + """Test that closing the input stream signals completion. + + When the input generator finishes (naturally or via return), + the session should complete with finished=True on the last output. 
+ """ + request_id = "test_close_completion" + sampling_params = get_sampling_params(max_tokens=15) + + input_generator_finished = False + + async def limited_input_generator() -> AsyncGenerator[StreamingInput, None]: + nonlocal input_generator_finished + yield StreamingInput(prompt="What is 2 + 2? The answer is") + # Generator finishes naturally here + input_generator_finished = True + + outputs, _ = await collect_outputs( + engine.generate(limited_input_generator(), sampling_params, request_id) + ) + + # Verify the input generator completed + assert input_generator_finished, "Input generator should have finished" + + # Verify we got a finished output + assert len(outputs) > 0, "Should have received outputs" + assert outputs[-1].finished, "Last output should be marked as finished" + + # Verify the session is cleaned up + assert not engine.output_processor.has_unfinished_requests(), ( + "Should have no unfinished requests" + ) + + print("Close completion test passed") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_abort_queued_inputs(engine: AsyncLLM): + """Test that aborting the session cancels queued inputs. + + When multiple inputs are queued and the session is aborted, + all pending inputs should be cancelled. 
+ """ + request_id = "test_abort_queued" + # Use large max_tokens to ensure we have time to queue inputs + sampling_params = get_sampling_params(max_tokens=2000) + + inputs_sent = 0 + input_cancelled = False + + async def many_inputs_generator() -> AsyncGenerator[StreamingInput, None]: + nonlocal inputs_sent, input_cancelled + try: + # Send several inputs to fill the queue + for i in range(10): + yield StreamingInput(prompt=f" Part {i}: Tell me about the number {i}.") + inputs_sent += 1 + # Small delay to interleave with output processing + await asyncio.sleep(0.05) + except asyncio.CancelledError: + input_cancelled = True + raise + + outputs_received = 0 + output_gen = engine.generate(many_inputs_generator(), sampling_params, request_id) + + try: + async for output in output_gen: + outputs_received += 1 + # Cancel after receiving some outputs + if outputs_received >= 10: + break + finally: + await output_gen.aclose() + + # Give time for cleanup + await asyncio.sleep(0.5) + + # Verify we received some outputs + assert outputs_received >= 10, "Should have received outputs before abort" + + # Verify the input generator was cancelled OR finished naturally + # (it might finish naturally if all inputs were sent before cancel) + assert input_cancelled or inputs_sent == 10, ( + f"Input generator should have been cancelled or completed. " + f"cancelled={input_cancelled}, inputs_sent={inputs_sent}" + ) + + # Verify the session is cleaned up + assert not engine.output_processor.has_unfinished_requests(), ( + "Should have no unfinished requests after abort" + ) + + print( + f"Abort queued test passed. 
Sent {inputs_sent} inputs, " + f"received {outputs_received} outputs" + ) + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_error_propagation(engine: AsyncLLM): + """Test that errors in the input generator are propagated to the caller.""" + request_id = "test_error_propagation" + sampling_params = get_sampling_params(max_tokens=20) + + class InputError(Exception): + pass + + async def error_input_generator() -> AsyncGenerator[StreamingInput, None]: + yield StreamingInput(prompt="Start with this") + await asyncio.sleep(0.1) + raise InputError("Simulated input error") + + # Note: The current implementation catches exceptions and puts them + # in the queue, so we should get the error when iterating outputs + with pytest.raises(InputError, match="Simulated input error"): + async for _ in engine.generate( + error_input_generator(), sampling_params, request_id + ): + pass + + # Give time for cleanup + await asyncio.sleep(0.3) + + # Verify the session is cleaned up + assert not engine.output_processor.has_unfinished_requests(), ( + "Should have no unfinished requests after error" + ) + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_multiple_concurrent_sessions(engine: AsyncLLM): + """Test multiple concurrent streaming input sessions. + + Multiple streaming sessions should be able to run concurrently + without interfering with each other. 
+ """ + num_sessions = 3 + results: list[tuple[str, str]] = [] + + async def run_session(session_id: int) -> tuple[str, str]: + request_id = f"test_concurrent_{session_id}" + sampling_params = get_sampling_params(max_tokens=10) + + prompts = [f"Session {session_id}: Hello", f" world from session {session_id}"] + + async def input_gen() -> AsyncGenerator[StreamingInput, None]: + for prompt in prompts: + yield StreamingInput(prompt=prompt) + await asyncio.sleep(0.1) + + _, text = await collect_outputs( + engine.generate(input_gen(), sampling_params, request_id) + ) + return request_id, text + + # Run sessions concurrently + tasks = [asyncio.create_task(run_session(i)) for i in range(num_sessions)] + results = await asyncio.gather(*tasks) + + # Verify all sessions completed + assert len(results) == num_sessions + + for request_id, text in results: + assert len(text) > 0, f"Session {request_id} should have generated text" + print(f"{request_id}: {text}") + + # Verify cleanup + assert not engine.output_processor.has_unfinished_requests() + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_per_chunk_sampling_params(engine: AsyncLLM): + """Test that per-chunk sampling params are respected. + + Each StreamingInput can have its own sampling_params. 
+ """ + request_id = "test_per_chunk_params" + base_params = get_sampling_params(max_tokens=10) + + async def variable_params_generator() -> AsyncGenerator[StreamingInput, None]: + # First chunk with base params + yield StreamingInput(prompt="Count to five:", sampling_params=base_params) + + # Second chunk with different max_tokens + chunk_params = get_sampling_params(max_tokens=5) + yield StreamingInput( + prompt=" Now count backwards:", sampling_params=chunk_params + ) + + outputs, full_text = await collect_outputs( + engine.generate(variable_params_generator(), base_params, request_id) + ) + + assert len(outputs) > 0, "Should have received outputs" + assert outputs[-1].finished, "Last output should be finished" + assert len(full_text) > 0, "Should have generated text" + + print(f"Per-chunk params test generated: {full_text}") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_empty_generator(engine: AsyncLLM): + """Test behavior when the input generator yields nothing. + + An empty generator should still produce a finished output. + """ + request_id = "test_empty_generator" + sampling_params = get_sampling_params(max_tokens=10) + + async def empty_generator() -> AsyncGenerator[StreamingInput, None]: + # Don't yield anything + return + yield # Make it a generator + + outputs: list[RequestOutput] = [] + async for output in engine.generate(empty_generator(), sampling_params, request_id): + outputs.append(output) + + # Should still get a finished marker + assert len(outputs) >= 1, "Should receive at least one output" + assert outputs[-1].finished, "Should have a finished output" + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_single_chunk(engine: AsyncLLM): + """Test streaming input with a single chunk. + + This is effectively the same as a regular non-streaming request, + but using the streaming input API. 
+ """ + request_id = "test_single_chunk" + sampling_params = get_sampling_params(max_tokens=15) + + async def single_chunk_generator() -> AsyncGenerator[StreamingInput, None]: + yield StreamingInput(prompt="What color is the sky? The sky is") + + outputs, full_text = await collect_outputs( + engine.generate(single_chunk_generator(), sampling_params, request_id) + ) + + assert len(outputs) > 0 + assert outputs[-1].finished + assert "blue" in full_text.lower() or len(full_text) > 0 + + print(f"Single chunk test generated: {full_text}") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_reuse_request_id(engine: AsyncLLM): + """Test that request IDs can be reused after a session completes.""" + request_id = "test_reuse_id" + sampling_params = get_sampling_params(max_tokens=5) + + # First session + async def gen1() -> AsyncGenerator[StreamingInput, None]: + yield StreamingInput(prompt="First session") + + _, text1 = await collect_outputs( + engine.generate(gen1(), sampling_params, request_id) + ) + + # Second session with same ID + async def gen2() -> AsyncGenerator[StreamingInput, None]: + yield StreamingInput(prompt="Second session") + + _, text2 = await collect_outputs( + engine.generate(gen2(), sampling_params, request_id) + ) + + assert len(text1) > 0 + assert len(text2) > 0 + assert not engine.output_processor.has_unfinished_requests() + + print(f"Reuse ID test: session 1: {text1}, session 2: {text2}") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_validation_errors(engine: AsyncLLM): + """Test that invalid configurations raise appropriate errors.""" + + async def dummy_generator() -> AsyncGenerator[StreamingInput, None]: + yield StreamingInput(prompt="test") + + # Test n > 1 is rejected + with pytest.raises(ValueError, match="Input streaming not currently supported"): + params_n2 = SamplingParams(max_tokens=10, n=2) + async for _ in engine.generate(dummy_generator(), params_n2, "test_n2"): + pass + + # 
Test FINAL_ONLY is rejected + with pytest.raises(ValueError, match="Input streaming not currently supported"): + params_final = SamplingParams( + max_tokens=10, output_kind=RequestOutputKind.FINAL_ONLY + ) + async for _ in engine.generate(dummy_generator(), params_final, "test_final"): + pass + + # Test stop strings are rejected + with pytest.raises(ValueError, match="Input streaming not currently supported"): + params_stop = SamplingParams(max_tokens=10, stop=["stop"]) + async for _ in engine.generate(dummy_generator(), params_stop, "test_stop"): + pass + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_input_delayed_generator_exit(engine: AsyncLLM): + """Test that output generator exits when input generator closes after outputs. + + This tests the case where: + 1. Multiple inputs are sent and fully processed + 2. The engine has finished + 3. The input generator doesn't exit until after the engine finishes + 4. The output generator should exit properly once the input generator exits + """ + request_id = "test_delayed_exit" + sampling_params = get_sampling_params(max_tokens=10) + + engine_finished_event = asyncio.Event() + input_generator_exited = False + finish_count = 0 + + async def delayed_exit_input_generator() -> AsyncGenerator[StreamingInput, None]: + nonlocal input_generator_exited + # Send all inputs immediately + yield StreamingInput(prompt="Hello, my name is") + yield StreamingInput(prompt=" Alice") + + # Wait until the engine has finished generating before exiting + await engine_finished_event.wait() + + # Add a small delay to ensure we're testing the "delayed exit" case + await asyncio.sleep(0.1) + input_generator_exited = True + + outputs: list[RequestOutput] = [] + full_text = "" + + async for output in engine.generate( + delayed_exit_input_generator(), sampling_params, request_id + ): + outputs.append(output) + if output.outputs and output.outputs[0].text: + full_text += output.outputs[0].text + + # Signal when the engine 
finishes both input chunks (each gets a finish_reason) + # Note: output.finished will be False while input stream is open + if output.outputs and output.outputs[0].finish_reason is not None: + finish_count += 1 + if finish_count == 2: + engine_finished_event.set() + + # Verify the input generator exited properly + assert input_generator_exited, ( + "Input generator should have exited after engine finished" + ) + + # Verify we got outputs + assert len(outputs) > 0, "Should have received outputs" + + # Verify we generated some text + assert len(full_text) > 0, "Should have generated text" + + # Verify the session is cleaned up + assert not engine.output_processor.has_unfinished_requests(), ( + "Should have no unfinished requests" + ) + + print(f"Delayed exit test passed. Generated: {full_text}") diff --git a/tests/v1/streaming/__init__.py b/tests/v1/streaming_input/__init__.py similarity index 100% rename from tests/v1/streaming/__init__.py rename to tests/v1/streaming_input/__init__.py diff --git a/tests/v1/streaming/test_async_llm_streaming.py b/tests/v1/streaming_input/test_async_llm_streaming.py similarity index 67% rename from tests/v1/streaming/test_async_llm_streaming.py rename to tests/v1/streaming_input/test_async_llm_streaming.py index bc097c42aeb0..913576f70006 100644 --- a/tests/v1/streaming/test_async_llm_streaming.py +++ b/tests/v1/streaming_input/test_async_llm_streaming.py @@ -93,30 +93,6 @@ async def mock_add_request(*args, **kwargs): assert outputs[1].finished is True -@pytest.fixture -def mock_async_llm_streaming(): - """Create a mock AsyncLLM for generate with async generator.""" - llm = MagicMock(spec=AsyncLLM) - - llm.vllm_config = MagicMock() - llm.vllm_config.cache_config.kv_sharing_fast_prefill = False - llm.model_config = MagicMock() - llm.model_config.max_model_len = 2048 - llm.log_requests = False - llm.errored = False - llm._pause_cond = asyncio.Condition() - llm._paused = False - - llm._run_output_handler = MagicMock() - llm.abort = 
AsyncMock() - - # Bind the real methods - llm.generate = AsyncLLM.generate.__get__(llm, AsyncLLM) - llm._add_streaming_request = AsyncLLM._add_streaming_request.__get__(llm, AsyncLLM) - - return llm - - def make_output(request_id: str, finished: bool) -> RequestOutput: """Helper to create a RequestOutput.""" return RequestOutput( @@ -130,42 +106,66 @@ def make_output(request_id: str, finished: bool) -> RequestOutput: @pytest.mark.asyncio -async def test_generate_with_async_generator(mock_async_llm_streaming): - """Test generate with an async input generator.""" +async def test_generate_with_async_generator(): + """Test generate with an async input generator. + + With the new streaming input API, completion is signaled by finishing + the input generator (not via a resumable flag). Each input chunk + produces intermediate outputs, and the final output has finished=True. + """ request_id = "test" sampling_params = SamplingParams(max_tokens=10) - segment_count = 0 - shared_queue = RequestOutputCollector(RequestOutputKind.FINAL_ONLY, request_id) - - async def mock_add_request(*args, **kwargs): - nonlocal segment_count - segment_count += 1 - current_segment = segment_count + llm = MagicMock(spec=AsyncLLM) + llm.vllm_config = MagicMock() + llm.vllm_config.cache_config.kv_sharing_fast_prefill = False + llm.model_config = MagicMock() + llm.model_config.max_model_len = 2048 + llm.log_requests = False + llm.errored = False + llm._pause_cond = asyncio.Condition() + llm._paused = False + llm._run_output_handler = MagicMock() + llm.abort = AsyncMock() - # Stagger outputs to prevent aggregation in RequestOutputCollector - async def produce_output(): - await asyncio.sleep(current_segment * 0.05) - shared_queue.put(make_output(request_id, finished=True)) + # Bind the real generate method + llm.generate = AsyncLLM.generate.__get__(llm, AsyncLLM) - asyncio.create_task(produce_output()) - return shared_queue + # Track inputs processed + inputs_received = [] + queue = 
RequestOutputCollector(RequestOutputKind.DELTA, request_id) + + async def mock_add_request(req_id, prompt, params, *args, **kwargs): + # When prompt is an AsyncGenerator, process streaming inputs + if isinstance(prompt, AsyncGenerator): + # Process inputs in background, produce outputs + async def handle_stream(): + async for input_chunk in prompt: + inputs_received.append(input_chunk.prompt) + # Each input produces an intermediate output + queue.put(make_output(req_id, finished=False)) + await asyncio.sleep(0.01) + # Final output when stream ends + queue.put(make_output(req_id, finished=True)) + + asyncio.create_task(handle_stream()) + return queue + return queue - mock_async_llm_streaming.add_request = mock_add_request + llm.add_request = mock_add_request async def input_generator() -> AsyncGenerator[StreamingInput, None]: - yield StreamingInput( - prompt="Hello", sampling_params=sampling_params, resumable=True - ) - yield StreamingInput( - prompt=" world", sampling_params=sampling_params, resumable=False - ) + yield StreamingInput(prompt="Hello", sampling_params=sampling_params) + yield StreamingInput(prompt=" world", sampling_params=sampling_params) outputs = [] - async for output in mock_async_llm_streaming.generate( - input_generator(), None, request_id - ): + async for output in llm.generate(input_generator(), sampling_params, request_id): outputs.append(output) - assert len(outputs) == 2 - assert segment_count == 2 + # Two intermediate outputs + one final output + assert len(outputs) == 3 + assert outputs[0].finished is False + assert outputs[1].finished is False + assert outputs[2].finished is True + # Both inputs were processed + assert inputs_received == ["Hello", " world"] diff --git a/tests/v1/streaming/test_gpu_model_runner_streaming.py b/tests/v1/streaming_input/test_gpu_model_runner_streaming.py similarity index 100% rename from tests/v1/streaming/test_gpu_model_runner_streaming.py rename to 
tests/v1/streaming_input/test_gpu_model_runner_streaming.py diff --git a/tests/v1/streaming/test_scheduler_streaming.py b/tests/v1/streaming_input/test_scheduler_streaming.py similarity index 82% rename from tests/v1/streaming/test_scheduler_streaming.py rename to tests/v1/streaming_input/test_scheduler_streaming.py index 60230307840d..0387d31c98e9 100644 --- a/tests/v1/streaming/test_scheduler_streaming.py +++ b/tests/v1/streaming_input/test_scheduler_streaming.py @@ -55,6 +55,7 @@ def create_scheduler() -> Scheduler: vllm_config.model_config.skip_tokenizer_init = True vllm_config.model_config.is_multimodal_model = False vllm_config.model_config.max_model_len = 1024 + vllm_config.model_config.enable_return_routed_experts = False vllm_config.cache_config = MagicMock() vllm_config.cache_config.num_gpu_blocks = 1000 vllm_config.cache_config.enable_prefix_caching = False @@ -63,7 +64,10 @@ def create_scheduler() -> Scheduler: kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( - ["layer"], FullAttentionSpec(16, 1, 1, torch.float32, False) + ["layer"], + FullAttentionSpec( + block_size=16, num_kv_heads=1, head_size=1, dtype=torch.float32 + ), ) ], ) @@ -118,27 +122,30 @@ def test_update_request_as_session_max_token(self): new_request.sampling_params = SamplingParams(max_tokens=10) new_request.max_tokens = 10 # Additional max_tokens from new request - session.streaming_queue.append(StreamingUpdate.from_request(new_request)) - scheduler._update_request_as_session(session) + update = StreamingUpdate.from_request(new_request) + scheduler._update_request_as_session(session, update) assert session.sampling_params.max_tokens == 10 - assert session.max_tokens == 20 # 10 + 10 + # _update_request_as_session clears output tokens first, so + # max_tokens = num_output_tokens (0) + update.max_tokens (10) = 10 + assert session.max_tokens == 10 session.num_computed_tokens = len(session.prompt_token_ids) - # only generated additional 5 - session._output_token_ids = [1] * 15 + # 
Simulate generating 5 more output tokens + session._output_token_ids = [1] * 5 new_request2 = DummyRequest( request_id="session", prompt_token_ids=[7, 8, 9], ) new_request2.sampling_params = SamplingParams(max_tokens=10) new_request2.max_tokens = 10 - session.streaming_queue.append(StreamingUpdate.from_request(new_request2)) - scheduler._update_request_as_session(session) + update2 = StreamingUpdate.from_request(new_request2) + scheduler._update_request_as_session(session, update2) assert session.sampling_params.max_tokens == 10 - assert session.max_tokens == 25 # 15 + 10 + # Again, output tokens are cleared first, so max_tokens = 0 + 10 = 10 + assert session.max_tokens == 10 def test_update_request_as_session(self): scheduler = create_scheduler() @@ -155,8 +162,8 @@ def test_update_request_as_session(self): ) new_request.sampling_params = SamplingParams(max_tokens=10) - session.streaming_queue.append(StreamingUpdate.from_request(new_request)) - scheduler._update_request_as_session(session) + update = StreamingUpdate.from_request(new_request) + scheduler._update_request_as_session(session, update) assert session.prompt_token_ids == [1, 2, 3, 4, 5, 6] assert session._all_token_ids == [1, 2, 3, 4, 5, 6] @@ -190,8 +197,8 @@ def test_update_request_as_session_with_multimodal(self): prompt_token_ids=[4, 5, 6, 7], mm_features=[mm_feature], ) - session.streaming_queue.append(StreamingUpdate.from_request(new_request)) - scheduler._update_request_as_session(session) + update = StreamingUpdate.from_request(new_request) + scheduler._update_request_as_session(session, update) assert len(session.mm_features) == 2 assert session.mm_features[0].mm_position.offset == 1 @@ -199,6 +206,13 @@ def test_update_request_as_session_with_multimodal(self): assert session.mm_features[1].mm_position.offset == 5 def test_process_streaming_requests_with_finish_session(self): + """Test that a non-resumable request signals stream completion. 
+ + With the new streaming API, completion is signaled by closing/finishing + the input generator. When a non-resumable request is added to a session + in WAITING_FOR_STREAMING_REQ state, the session is finished immediately + with FINISHED_ABORTED status. + """ scheduler = create_scheduler() session = DummyRequest( @@ -210,6 +224,7 @@ def test_process_streaming_requests_with_finish_session(self): session.status = RequestStatus.WAITING_FOR_STREAMING_REQ session.num_computed_tokens = len(session.prompt_token_ids) + # A non-resumable request signals stream completion close_request = DummyRequest( request_id="session", prompt_token_ids=[0], @@ -217,30 +232,25 @@ def test_process_streaming_requests_with_finish_session(self): max_tokens=1, ) scheduler.add_request(close_request) - assert close_request.status == RequestStatus.WAITING - assert len(session.streaming_queue) == 1 - sout = scheduler.schedule() - mro = ModelRunnerOutput( - req_ids=[session.request_id], - req_id_to_index={session.request_id: 0}, - sampled_token_ids=[[0]], - logprobs=None, - prompt_logprobs_dict={session.request_id: None}, - pooler_output=None, - ) - out = scheduler.update_from_output(sout, mro) - assert session.status == RequestStatus.FINISHED_LENGTH_CAPPED - assert len(out) == 1 - assert out[0].outputs[0].request_id == session.request_id - assert out[0].outputs[0].resumable is False + # The session should be immediately finished (stream completed) + assert session.status == RequestStatus.FINISHED_ABORTED + # The session should be removed from the scheduler + assert session.request_id not in scheduler.requests def test_streaming_request_session_update(self): + """Test that a resumable request updates a waiting session directly. + + When a session is in WAITING_FOR_STREAMING_REQ state and a new resumable + request arrives, the update is applied directly via _update_request_as_session, + not queued. 
+ """ scheduler = create_scheduler() session = DummyRequest( request_id="session", prompt_token_ids=[1, 2, 3], + resumable=True, ) scheduler.add_request(session) session.status = RequestStatus.WAITING_FOR_STREAMING_REQ @@ -253,13 +263,16 @@ def test_streaming_request_session_update(self): ) scheduler.add_request(next_request) - assert next_request.status == RequestStatus.WAITING - assert len(session.streaming_queue) == 1 + + # With the new behavior, when session is in WAITING_FOR_STREAMING_REQ, + # the update is applied directly (not queued), and session status + # becomes WAITING + assert session.status == RequestStatus.WAITING + assert session.prompt_token_ids == [1, 2, 3, 4, 5] _ = scheduler.schedule() assert session.status == RequestStatus.RUNNING - assert session.prompt_token_ids == [1, 2, 3, 4, 5] def test_update_request_as_session_with_output_tokens(self): scheduler = create_scheduler() @@ -280,14 +293,19 @@ def test_update_request_as_session_with_output_tokens(self): prompt_token_ids=[4, 5], ) - session.streaming_queue.append(StreamingUpdate.from_request(new_request)) - scheduler._update_request_as_session(session) + update = StreamingUpdate.from_request(new_request) + scheduler._update_request_as_session(session, update) - # Verify the last output token (11) was removed, and new prompt tokens added + # _update_request_as_session keeps computed output tokens (they become + # part of the prompt) and only discards the final uncomputed sampled + # token. Computed output token 10 is kept, uncomputed token 11 is + # discarded. 
assert session._all_token_ids == [1, 2, 3, 10, 4, 5] - assert session.prompt_token_ids == [1, 2, 3, 4, 5] - # Verify output tokens list is unchanged (only removed from _all_token_ids) - assert session._output_token_ids == [10, 11] + assert session.prompt_token_ids == [1, 2, 3, 10, 4, 5] + # Output tokens list is cleared + assert session._output_token_ids == [] + # num_computed_tokens is unchanged (KV cache still valid for computed + # tokens) assert session.num_computed_tokens == 4 # Verify that the next schedule will only process the new prompt tokens # num_new_tokens = num_tokens - num_computed_tokens = 6 - 4 = 2 @@ -369,9 +387,13 @@ def test_streaming_e2e_lifecycle(self): # Step 3: Simulate model runner caching the prompt_token_ids # This simulates gpu_model_runner.py:706-720 CachedRequestState creation + # The model runner makes a copy of prompt_token_ids when creating + # CachedRequestState cached_state_cycle1 = { "req_id": session.request_id, - "prompt_token_ids": new_req_data_cycle1.prompt_token_ids, # Must be a copy! + "prompt_token_ids": list( + new_req_data_cycle1.prompt_token_ids + ), # Explicit copy "output_token_ids": [], "num_computed_tokens": 0, } @@ -495,31 +517,39 @@ def test_streaming_e2e_lifecycle(self): prompt_token_ids=[4, 5], ) scheduler.add_request(new_request) - assert new_request.status == RequestStatus.WAITING - assert len(session.streaming_queue) == 1 - # Step 13: Scheduler merges new request into session and schedules + # With the new streaming API, when session is in WAITING_FOR_STREAMING_REQ, + # the update is applied directly via _update_request_as_session (not queued). + # The session status becomes WAITING after the update is applied. 
+ assert session.status == RequestStatus.WAITING + + # Step 13: Scheduler schedules the updated session scheduler_output_cycle3 = scheduler.schedule() - # Verify scheduler created NewRequestData with merged _all_token_ids + # Verify scheduler created NewRequestData with merged prompt_token_ids assert len(scheduler_output_cycle3.scheduled_new_reqs) == 1 assert ( scheduler_output_cycle3.scheduled_new_reqs[0].prompt_token_ids - == session._all_token_ids + == session.prompt_token_ids ) assert ( scheduler_output_cycle3.num_scheduled_tokens[session.request_id] == 2 ) # Only new tokens [4, 5] - # STOP_TOKEN removed from _all_token_ids + # Computed output tokens are kept (become part of prompt), only the + # final uncomputed sampled token (STOP_TOKEN) is discarded assert session._all_token_ids == [1, 2, 3, 10, 4, 5] - assert session.prompt_token_ids == [1, 2, 3, 4, 5] # Only prompts - assert session._output_token_ids == [10, STOP_TOKEN] + assert session.prompt_token_ids == [1, 2, 3, 10, 4, 5] # Includes kept output + assert session._output_token_ids == [] # Output tokens are cleared # Step 14: Model runner caches NEW prompt_token_ids reference + # The model runner makes a copy of prompt_token_ids when creating + # CachedRequestState new_req_data_cycle3 = scheduler_output_cycle3.scheduled_new_reqs[0] cached_state_cycle3 = { "req_id": session.request_id, - "prompt_token_ids": new_req_data_cycle3.prompt_token_ids, + "prompt_token_ids": list( + new_req_data_cycle3.prompt_token_ids + ), # Explicit copy "output_token_ids": [], "num_computed_tokens": session.num_computed_tokens, } diff --git a/vllm/outputs.py b/vllm/outputs.py index cf23745c447d..5bd460aad464 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -192,6 +192,16 @@ def __repr__(self) -> str: ) +# Sentinel to indicate request is finished, used with streaming inputs. 
+STREAM_FINISHED = RequestOutput( + request_id="", + prompt=None, + prompt_token_ids=None, + prompt_logprobs=None, + outputs=[], + finished=True, +) + _O = TypeVar("_O", default=PoolingOutput) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8dedd756b050..2175882b9661 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2,13 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools import time -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Iterable from dataclasses import replace from typing import Any import numpy as np -import torch from vllm import envs from vllm.compilation.cuda_graph import CUDAGraphStat @@ -51,10 +50,7 @@ from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.perf import ModelMetrics, PerfStats -from vllm.v1.metrics.stats import ( - PrefixCacheStats, - SchedulerStats, -) +from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus, StreamingUpdate from vllm.v1.spec_decode.metrics import SpecDecodingStats @@ -170,7 +166,7 @@ def __init__( # Counter for requests waiting for streaming input. Used to calculate # number of unfinished requests - self.num_waiting_for_streaming: int = 0 + self.num_waiting_for_streaming_input: int = 0 # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() @@ -511,13 +507,10 @@ def schedule(self) -> SchedulerOutput: # Streaming: skip request if still waiting for next streaming req. if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ: - if request.streaming_queue: - # Updates the request status to WAITING. 
- self._update_request_as_session(request) - else: - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) - continue + assert not request.streaming_queue + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue # Check that adding the request still respects the max_loras # constraint. @@ -754,7 +747,7 @@ def schedule(self) -> SchedulerOutput: scheduled_new_reqs = scheduled_new_reqs + scheduled_resumed_reqs scheduled_resumed_reqs = [] new_reqs_data = [ - self._make_new_request_data( + NewRequestData.from_request( req, req_to_new_blocks[req.request_id].get_block_ids(), req._all_token_ids, @@ -763,7 +756,7 @@ def schedule(self) -> SchedulerOutput: ] else: new_reqs_data = [ - self._make_new_request_data( + NewRequestData.from_request( req, req_to_new_blocks[req.request_id].get_block_ids() ) for req in scheduled_new_reqs @@ -869,27 +862,31 @@ def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: # it will also affect the scheduler output. self.finished_req_ids = set() - def _update_request_as_session(self, session: Request) -> None: + def _update_request_as_session( + self, session: Request, update: StreamingUpdate + ) -> None: """ Updates the waiting session with the next streaming update. - Removes the last output token (not yet scheduled) from `_all_token_ids` - because the new request's prompt tokens will replace it. Typically the decoded - outputs are scheduled as the next input in autoregressive decoding. When we - receive a new streaming request, the new prompt becomes our next input, so the - last output token is no longer needed and will not join the kv cache. This - ensures correct calculation of `num_new_tokens` in `schedule`. + Discards the last sampled output token from the prior input chunk. """ - assert session.streaming_queue is not None - update = session.streaming_queue.popleft() + if update is None: + # Streaming-input request finished. 
+ self.finish_requests(session.request_id, RequestStatus.FINISHED_ABORTED) + return - num_new_tokens = session.num_tokens - session.num_computed_tokens - assert num_new_tokens in (0, 1), f"got {num_new_tokens=}" - if num_new_tokens == 1: - assert session._all_token_ids[-1] == session._output_token_ids[-1] - del session._all_token_ids[-1] + # Current streaming input behaviour: Keep only computed output tokens + # (discard final sampled output token). + num_computed_tokens = session.num_computed_tokens + kept_output_tokens = session._all_token_ids[ + session.num_prompt_tokens : num_computed_tokens + ] + del session._all_token_ids[num_computed_tokens:] + session._output_token_ids.clear() + assert session.prompt_token_ids is not None + # Extend prompt with kept output tokens. + session.prompt_token_ids.extend(kept_output_tokens) - session.resumable = update.resumable if update.mm_features: base = session.num_tokens for mm_feature in update.mm_features: @@ -899,50 +896,21 @@ def _update_request_as_session(self, session: Request) -> None: session.mm_features.extend(update.mm_features) session._all_token_ids.extend(update.prompt_token_ids or ()) - if session.prompt_token_ids is None: - session.prompt_token_ids = [] session.prompt_token_ids.extend(update.prompt_token_ids or ()) - if session.prompt_embeds is not None and update.prompt_embeds is not None: - session.prompt_embeds = torch.cat( - [session.prompt_embeds, update.prompt_embeds] - ) - elif update.prompt_embeds is not None: - session.prompt_embeds = update.prompt_embeds - session.max_tokens = session.num_output_tokens + update.max_tokens + # Update block hashes for the new tokens + # (mirrors Request.append_output_token_ids) + if session.get_hash_new_full_blocks is not None: + session.block_hashes.extend(session.get_hash_new_full_blocks()) + session.num_prompt_tokens = len(session.prompt_token_ids) session.arrival_time = update.arrival_time session.sampling_params = update.sampling_params if session.status == 
RequestStatus.WAITING_FOR_STREAMING_REQ: - self.num_waiting_for_streaming -= 1 + self.num_waiting_for_streaming_input -= 1 session.status = RequestStatus.WAITING if self.log_stats: session.record_event(EngineCoreEventType.QUEUED) - def _make_new_request_data( - self, - request: Request, - block_ids: tuple[list[int], ...], - prefill_token_ids: list[int] | None = None, - ) -> NewRequestData: - """ - Creates NewRequestData for requests in waiting queue to be sent to - ModelRunner via SchedulerOutput.scheduled_new_reqs. - - For streaming requests, we send all tokens, including past inputs and - decoded outputs, through the prompt field. Updated streaming requests - create new entries in InputBatch, so we need the full input history to - ensure alignment of mm offsets, kv cache, and token ids. - - NOTE: Make sure that prompt_token_ids is a copy of the original request's - _all_token_ids. Since the scheduler updates _all_token_ids each iteration, the - corresponding prompt_token_ids reference in NewRequestData will be mistakenly - updated while decoding if we don't make a copy. - """ - req_data = NewRequestData.from_request(request, block_ids, prefill_token_ids) - if request.streaming_queue is not None: - req_data.prompt_token_ids = request._all_token_ids.copy() - return req_data - def _make_cached_request_data( self, running_reqs: list[Request], @@ -1285,7 +1253,12 @@ def update_from_output( stopped = True routed_experts = None + finish_reason = None if stopped: + # Capture finish_reason BEFORE _handle_stopped_request, which may + # reset the status to WAITING for streaming requests that continue. 
+ finish_reason = request.get_finished_reason() + if self.vllm_config.model_config.enable_return_routed_experts: kv_blocks = self.kv_cache_manager.get_blocks(request.request_id) block_ids = kv_blocks.get_block_ids()[0] @@ -1309,15 +1282,10 @@ def update_from_output( indices=slot_mapping ) - if request.resumable: - if request.streaming_queue: - self._update_request_as_session(request) - else: - request.status = RequestStatus.WAITING_FOR_STREAMING_REQ - self.num_waiting_for_streaming += 1 - self.waiting.add_request(request) - else: + finished = self._handle_stopped_request(request) + if finished: kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: stopped_running_reqs.add(request) else: @@ -1354,7 +1322,7 @@ def update_from_output( EngineCoreOutput( request_id=req_id, new_token_ids=new_token_ids, - finish_reason=request.get_finished_reason(), + finish_reason=finish_reason, new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, pooling_output=pooler_output, @@ -1365,7 +1333,6 @@ def update_from_output( num_cached_tokens=request.num_cached_tokens, routed_experts=routed_experts, num_nans_in_logits=request.num_nans_in_logits, - resumable=request.resumable, ) ) else: @@ -1450,6 +1417,24 @@ def update_from_output( return engine_core_outputs + def _handle_stopped_request(self, request: Request) -> bool: + """Return True if finished (can be False for resumable requests).""" + if not request.resumable: + return True + + if request.streaming_queue: + update = request.streaming_queue.popleft() + if update is None: + # Streaming request finished. 
+ return True + self._update_request_as_session(request, update) + else: + request.status = RequestStatus.WAITING_FOR_STREAMING_REQ + self.num_waiting_for_streaming_input += 1 + + self.waiting.add_request(request) + return False + def _update_request_with_output( self, request: Request, new_token_ids: list[int] ) -> tuple[list[int], bool]: @@ -1552,11 +1537,16 @@ def get_request_counts(self) -> tuple[int, int]: def add_request(self, request: Request) -> None: existing = self.requests.get(request.request_id) - if existing is not None and existing.streaming_queue is not None: - existing.streaming_queue.append(StreamingUpdate.from_request(request)) - if self.log_stats: - existing.record_event(EngineCoreEventType.QUEUED) + if existing is not None: + update = StreamingUpdate.from_request(request) + if existing.status == RequestStatus.WAITING_FOR_STREAMING_REQ: + self._update_request_as_session(existing, update) + else: + assert existing.streaming_queue is not None + existing.streaming_queue.append(update) else: + if request.resumable: + request.streaming_queue = deque() self.waiting.add_request(request) self.requests[request.request_id] = request if self.log_stats: @@ -1592,7 +1582,7 @@ def finish_requests( running_requests_to_remove.add(request) else: if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ: - self.num_waiting_for_streaming -= 1 + self.num_waiting_for_streaming_input -= 1 waiting_requests_to_remove.append(request) # Remove all requests from queues at once for better efficiency @@ -1627,7 +1617,7 @@ def _free_blocks(self, request: Request): del self.requests[request.request_id] def get_num_unfinished_requests(self) -> int: - num_waiting = len(self.waiting) - self.num_waiting_for_streaming + num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input return num_waiting + len(self.running) def has_finished_requests(self) -> bool: diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 483b6c6eb435..e8e44746bf47 100644 --- 
a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -146,8 +146,6 @@ class EngineCoreOutput( # A value greater than 0 indicates that the output is corrupted. num_nans_in_logits: int = 0 - resumable: bool = False - @property def finished(self) -> bool: return self.finish_reason is not None diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 186999b3055d..29def32b1744 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -8,11 +8,12 @@ from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy from dataclasses import dataclass -from typing import Any, cast +from typing import Any import torch import vllm.envs as envs +from vllm import TokensPrompt from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient @@ -21,10 +22,10 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tasks import SupportedTask from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tracing import init_tracer @@ -38,6 +39,7 @@ from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest +from vllm.v1.engine.utils import get_prompt_text from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import ( StatLoggerFactory, @@ -58,9 +60,20 @@ class StreamingInput: where inputs are provided via an async generator. 
""" - prompt: EngineCoreRequest | PromptType - sampling_params: SamplingParams - resumable: bool = True + prompt: PromptType + sampling_params: SamplingParams | None = None + + +class InputStreamError(Exception): + """Wrapper for errors from the input stream generator. + + This is used to propagate errors from the user's input generator + without wrapping them in EngineGenerateError. + """ + + def __init__(self, cause: Exception): + self.cause = cause + super().__init__(str(cause)) class AsyncLLM(EngineClient): @@ -273,7 +286,7 @@ async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: async def add_request( self, request_id: str, - prompt: EngineCoreRequest | PromptType, + prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], params: SamplingParams | PoolingParams, arrival_time: float | None = None, lora_request: LoRARequest | None = None, @@ -282,7 +295,6 @@ async def add_request( priority: int = 0, data_parallel_rank: int | None = None, prompt_text: str | None = None, - resumable: bool = False, ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" @@ -310,6 +322,20 @@ async def add_request( tokenization_kwargs, ) + if isinstance(prompt, AsyncGenerator): + # Streaming input case. + return await self._add_streaming_input_request( + request_id, + prompt, + params, + arrival_time, + lora_request, + tokenization_kwargs, + trace_headers, + priority, + data_parallel_rank, + ) + # Convert Input --> Request. if isinstance(prompt, EngineCoreRequest): request = prompt @@ -319,8 +345,8 @@ async def add_request( "does not match the EngineCoreRequest.request_id attribute. The " "latter will be used, and the former will be ignored." 
) - else: - if prompt_text is not None: + elif not isinstance(prompt, AsyncGenerator): + if prompt_text is not None: raise ValueError( "should only provide prompt_text with EngineCoreRequest" ) @@ -334,12 +360,8 @@ async def add_request( trace_headers, priority, data_parallel_rank, - resumable, ) - if isinstance(prompt, str): - prompt_text = prompt - elif isinstance(prompt, Mapping): - prompt_text = cast(str | None, prompt.get("prompt")) + prompt_text = get_prompt_text(prompt) self.input_processor.assign_request_id(request) @@ -352,12 +374,8 @@ async def add_request( async with self._pause_cond: await self._pause_cond.wait_for(lambda: not self._paused) - # Reuse output collector for streaming session, create new otherwise. - existing_state = self.output_processor.request_states.get(request_id) - if existing_state and existing_state.queue: - queue = existing_state.queue - else: - queue = RequestOutputCollector(params.output_kind, request.request_id) + # Create a new output collector for the request. 
+ queue = RequestOutputCollector(params.output_kind, request.request_id) # Use cloned params that may have been updated in process_inputs() params = request.params @@ -398,61 +416,103 @@ async def _add_request( if self.log_requests: logger.info("Added request %s.", request.request_id) - async def _add_streaming_request( + async def _add_streaming_input_request( self, request_id: str, input_stream: AsyncGenerator[StreamingInput, None], - streaming_done: asyncio.Event, - ) -> tuple[RequestOutputCollector, asyncio.Task | None, list[int]]: - """Handle async generator input stream for streaming sessions.""" - - first_input = await input_stream.__anext__() - - if ( - self.vllm_config.cache_config.kv_sharing_fast_prefill - and first_input.sampling_params.prompt_logprobs - ): - raise ValueError( - "--kv-sharing-fast-prefill produces incorrect logprobs for " - "prompt tokens, please disable it when the requests need " - "prompt logprobs" - ) + sampling_params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, + priority: int = 0, + data_parallel_rank: int | None = None, + ) -> RequestOutputCollector: + self._validate_streaming_input_sampling_params(sampling_params) + + inputs = dict( + arrival_time=arrival_time, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + trace_headers=trace_headers, + priority=priority, + data_parallel_rank=data_parallel_rank, + ) - queue = await self.add_request( - request_id, - first_input.prompt, - first_input.sampling_params, - resumable=first_input.resumable, + if not sampling_params.skip_clone: + sampling_params = sampling_params.clone() + sampling_params.skip_clone = True + + # Create request for validation, also used as the finished signal + # once the input stream is closed. 
+ final_req = self.input_processor.process_inputs( + request_id=request_id, + prompt=TokensPrompt(prompt_token_ids=[0]), + params=sampling_params, + **inputs, # type: ignore[arg-type] ) + self.input_processor.assign_request_id(final_req) + internal_req_id = final_req.request_id - # Track pending outputs: [count]. Use list for mutability in closure. - # Starts at 1 for first_input. Background task increments for each - # additional input. Main loop decrements when finished=True received. - pending_outputs = [1] + queue = RequestOutputCollector(sampling_params.output_kind, internal_req_id) - if first_input.resumable: - # Start background task to process remaining inputs - async def process_remaining(): - try: - async for streaming_input in input_stream: - pending_outputs[0] += 1 - await self.add_request( - request_id, - streaming_input.prompt, - streaming_input.sampling_params, - resumable=streaming_input.resumable, + async def handle_inputs(): + cancelled = False + try: + async for input_chunk in input_stream: + sp = input_chunk.sampling_params + if sp: + self._validate_streaming_input_sampling_params(sp) + else: + sp = sampling_params + req = self.input_processor.process_inputs( + request_id=internal_req_id, + prompt=input_chunk.prompt, + params=sp, + resumable=True, + **inputs, # type: ignore[arg-type] + ) + req.external_req_id = request_id + if req.prompt_embeds is not None: + raise ValueError( + "prompt_embeds not supported for streaming inputs" ) - if not streaming_input.resumable: - break - finally: - streaming_done.set() + prompt_text = get_prompt_text(input_chunk.prompt) + await self._add_request(req, prompt_text, None, 0, queue) + except (asyncio.CancelledError, GeneratorExit): + cancelled = True + except Exception as error: + # Wrap in InputStreamError so generate() can propagate it + # without wrapping in EngineGenerateError. 
+ queue.put(InputStreamError(error)) + finally: + queue._input_stream_task = None + if not cancelled: + # Send empty final request to indicate that inputs have + # finished. Don't send if cancelled (session was aborted). + await self._add_request(final_req, None, None, 0, queue) + + # Ensure output handler is running. + self._run_output_handler() - input_task = asyncio.create_task(process_remaining()) - else: - streaming_done.set() - input_task = None + queue._input_stream_task = asyncio.create_task(handle_inputs()) + return queue - return queue, input_task, pending_outputs + @staticmethod + def _validate_streaming_input_sampling_params( + params: SamplingParams | PoolingParams, + ): + if ( + not isinstance(params, SamplingParams) + or params.n > 1 + or params.output_kind == RequestOutputKind.FINAL_ONLY + or params.stop + ): + raise ValueError( + "Input streaming not currently supported " + "for pooling models, n > 1, request_kind = FINAL_ONLY " + "or with stop strings." + ) # TODO: we should support multiple prompts in one call, as you # can do with LLM.generate. So that for multi-prompt completion @@ -462,7 +522,7 @@ async def process_remaining(): async def generate( self, prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], - sampling_params: SamplingParams | None, + sampling_params: SamplingParams, request_id: str, *, prompt_text: str | None = None, @@ -485,50 +545,26 @@ async def generate( The caller of generate() iterates the returned AsyncGenerator, returning the RequestOutput back to the caller. - - For streaming sessions with an async generator input, pass an - AsyncGenerator[StreamingInput, None] as prompt. In this case, - sampling_params can be None as each StreamingInput has its own params. 
""" q: RequestOutputCollector | None = None - streaming_done: asyncio.Event | None = None - input_task: asyncio.Task | None = None - pending_outputs: list[int] | None = None try: - if isinstance(prompt, AsyncGenerator): - streaming_done = asyncio.Event() - q, input_task, pending_outputs = await self._add_streaming_request( - request_id, prompt, streaming_done - ) - else: - if sampling_params is None: - raise ValueError( - "sampling_params is required when prompt is not an " - "AsyncGenerator[StreamingInput, None]" - ) - q = await self.add_request( - request_id, - prompt, - sampling_params, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - trace_headers=trace_headers, - priority=priority, - data_parallel_rank=data_parallel_rank, - prompt_text=prompt_text, - ) + q = await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + trace_headers=trace_headers, + priority=priority, + data_parallel_rank=data_parallel_rank, + prompt_text=prompt_text, + ) # The output_handler task pushes items into the queue. # This task pulls from the queue and yields to caller. - while True: - if ( - input_task is not None - and input_task.done() - and (exc := input_task.exception()) - ): - raise exc - + finished = False + while not finished: # Note: drain queue without await if possible (avoids # task switching under load which helps performance). out = q.get_nowait() or await q.get() @@ -536,21 +572,14 @@ async def generate( # Note: both OutputProcessor and EngineCore handle their # own request cleanup based on finished. 
assert isinstance(out, RequestOutput) - yield out - - if out.finished: - if pending_outputs is not None: - pending_outputs[0] -= 1 - # Exit when no more pending outputs - if pending_outputs is None or pending_outputs[0] == 0: - break + finished = out.finished + if out is not STREAM_FINISHED: + yield out # If the request is disconnected by the client, generate() # is cancelled or the generator is garbage collected. So, # we abort the request if we end up here. except (asyncio.CancelledError, GeneratorExit): - if input_task is not None: - input_task.cancel() if q is not None: await self.abort(q.request_id, internal=True) if self.log_requests: @@ -559,24 +588,26 @@ async def generate( # Engine is dead. Do not abort since we shut down. except EngineDeadError: - if input_task is not None: - input_task.cancel() if self.log_requests: logger.info("Request %s failed (engine dead).", request_id) raise # Request validation error. except ValueError as e: - if input_task is not None: - input_task.cancel() if self.log_requests: logger.info("Request %s failed (bad request): %s.", request_id, e) raise + # Error from input stream generator - propagate directly. + except InputStreamError as e: + if q is not None: + await self.abort(q.request_id, internal=True) + if self.log_requests: + logger.info("Request %s failed (input error): %s.", request_id, e) + raise e.cause from e + # Unexpected error in the generate() task (possibly recoverable). 
except Exception as e: - if input_task is not None: - input_task.cancel() if q is not None: await self.abort(q.request_id, internal=True) if self.log_requests: @@ -590,6 +621,9 @@ async def generate( ) logger.info("Request %s failed due to %s.", request_id, s) raise EngineGenerateError() from e + finally: + if q is not None: + q.close() def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" @@ -815,6 +849,9 @@ async def encode( if self.log_requests: logger.info("Request %s failed.", request_id) raise EngineGenerateError() from e + finally: + if q is not None: + q.close() @property def tokenizer(self) -> TokenizerLike | None: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index ca7a44b78ca5..c497468bbd3e 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Iterable from dataclasses import dataclass from typing import Any, cast @@ -12,6 +12,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import ( + STREAM_FINISHED, CompletionOutput, PoolingOutput, PoolingRequestOutput, @@ -51,6 +52,8 @@ def __init__(self, output_kind: RequestOutputKind, request_id: str): self.output: RequestOutput | PoolingRequestOutput | Exception | None = None self.ready = asyncio.Event() + self._input_stream_task: asyncio.Task | None = None + def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None: """Non-blocking put operation.""" if self.output is None or isinstance(output, Exception): @@ -87,6 +90,16 @@ def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None: raise output return output + def close(self): + if self._input_stream_task is not None: + self._input_stream_task.cancel() + self._input_stream_task = None + + 
def __del__(self): + if (task := self._input_stream_task) is not None: + task.get_loop().call_soon_threadsafe(task.cancel) + self._input_stream_task = None + @dataclass class OutputProcessorOutput: @@ -94,6 +107,20 @@ class OutputProcessorOutput: reqs_to_abort: list[str] +@dataclass +class StreamingUpdate: + """Streaming input update data for output processor. + + Contains the incremental prompt data to be applied to a request state + when the current sub-request completes. + """ + + prompt: str | None + prompt_token_ids: list[int] | None + arrival_time: float + final: bool = False + + class RequestState: def __init__( self, @@ -116,6 +143,7 @@ def __init__( top_p: float | None = None, n: int | None = None, temperature: float | None = None, + stream_input: bool = False, ): self.request_id = request_id self.external_req_id = external_req_id @@ -146,6 +174,31 @@ def __init__( self.stream_interval = stream_interval self.sent_tokens_offset = 0 # Offset of sent tokens + # Streaming input queue + self.streaming_input = stream_input + self.input_chunk_queue: deque[StreamingUpdate] | None = ( + deque() if stream_input else None + ) + + def apply_streaming_update(self, update: StreamingUpdate) -> None: + # Apply the update to the request state. + self.streaming_input = not update.final + # TODO also include relevant output tokens in new prompt here + # (match scheduler behavior). 
+ if update.prompt: + self.prompt = ( + (self.prompt + update.prompt) if self.prompt else update.prompt + ) + if self.prompt_token_ids: + self.prompt_token_ids.extend(update.prompt_token_ids or ()) + else: + self.prompt_token_ids = update.prompt_token_ids or [] + assert self.prompt_token_ids is not None + self.prompt_len = len(self.prompt_token_ids) + if self.stats is not None: + self.stats.arrival_time = update.arrival_time + self.is_prefilling = True + @classmethod def from_new_request( cls, @@ -205,6 +258,7 @@ def from_new_request( queue=queue, log_stats=log_stats, stream_interval=stream_interval, + stream_input=request.resumable, ) def make_request_output( @@ -405,7 +459,6 @@ def abort_requests(self, request_ids: Iterable[str], internal: bool) -> list[str a parent request, in which case the associated child requests are aborted also. """ - internal_req_ids = [] for request_id in request_ids: if internal: @@ -464,8 +517,9 @@ def add_request( queue: RequestOutputCollector | None = None, ) -> None: request_id = request.request_id - if request_id in self.request_states: - self._update_streaming_request_state(request, prompt) + req_state = self.request_states.get(request_id) + if req_state is not None: + self._update_streaming_request_state(req_state, request, prompt) return req_state = RequestState.from_new_request( @@ -488,29 +542,37 @@ def add_request( self.external_req_ids[req_state.external_req_id].append(request_id) def _update_streaming_request_state( - self, request: EngineCoreRequest, prompt: str | None + self, req_state: RequestState, request: EngineCoreRequest, prompt: str | None ) -> None: - req_state = self.request_states[request.request_id] - if req_state.prompt and prompt: - req_state.prompt += prompt - elif prompt: - req_state.prompt = prompt - if request.prompt_token_ids: - if req_state.prompt_token_ids is None: - req_state.prompt_token_ids = [] - req_state.prompt_token_ids.extend(request.prompt_token_ids) - if req_state.prompt_embeds is not None 
and request.prompt_embeds is not None: - req_state.prompt_embeds = torch.cat( - [req_state.prompt_embeds, request.prompt_embeds] - ) - elif request.prompt_embeds is not None: - req_state.prompt_embeds = request.prompt_embeds - req_state.prompt_len = length_from_prompt_token_ids_or_embeds( - req_state.prompt_token_ids, req_state.prompt_embeds + """Queue a streaming update instead of immediately applying it.""" + if not request.resumable: + # Final request - just mark completion, don't add its dummy tokens. + if req_state.input_chunk_queue is None: + # Engine already finished - emit final output and clean up. + self._finish_request(req_state) + if req_state.queue is not None: + # Emit a final output with finished=True + # to unblock the generate() loop. + req_state.queue.put(STREAM_FINISHED) + elif req_state.input_chunk_queue: + req_state.input_chunk_queue[-1].final = True + else: + req_state.streaming_input = False + return + + update = StreamingUpdate( + prompt=prompt, + prompt_token_ids=request.prompt_token_ids, + arrival_time=request.arrival_time, ) - if req_state.stats is not None: - req_state.stats.arrival_time = request.arrival_time - req_state.is_prefilling = True + + # Apply request updates now if the last input already completed. + if req_state.input_chunk_queue is None: + req_state.apply_streaming_update(update) + req_state.input_chunk_queue = deque() + else: + # Queue the streaming update otherwise. + req_state.input_chunk_queue.append(update) def process_outputs( self, @@ -587,6 +649,9 @@ def process_outputs( kv_transfer_params, routed_experts, ): + if req_state.streaming_input: + request_output.finished = False + if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put(request_output) @@ -595,37 +660,49 @@ def process_outputs( request_outputs.append(request_output) # Free completed requests. 
- if finish_reason is not None and not engine_core_output.resumable: - self.request_states.pop(req_id) - - internal_ids = self.external_req_ids[req_state.external_req_id] - internal_ids.remove(req_id) - if not internal_ids: - del self.external_req_ids[req_state.external_req_id] - - # Remove parent request if applicable. - parent_req = req_state.parent_req - if parent_req and not parent_req.child_requests: - self.parent_requests.pop(parent_req.request_id, None) - if not self.request_states: - self._requests_drained.set() - if not engine_core_output.finished: - # If req not finished in EngineCore, but Detokenizer - # detected stop string, abort needed in EngineCore. - reqs_to_abort.append(req_id) - - # Track per-request stats - self._update_stats_from_finished( - req_state, finish_reason, iteration_stats - ) - if self.tracer: - self.do_tracing(engine_core_output, req_state, iteration_stats) + if finish_reason is not None: + if req_state.streaming_input: + if req_state.input_chunk_queue: + update = req_state.input_chunk_queue.popleft() + req_state.apply_streaming_update(update) + else: + req_state.input_chunk_queue = None + else: + self._finish_request(req_state) + if not engine_core_output.finished: + # If req not finished in EngineCore, but Detokenizer + # detected stop string, abort needed in EngineCore. + reqs_to_abort.append(req_id) + + # Track per-request stats + self._update_stats_from_finished( + req_state, finish_reason, iteration_stats + ) + if self.tracer: + self.do_tracing(engine_core_output, req_state, iteration_stats) return OutputProcessorOutput( request_outputs=request_outputs, reqs_to_abort=reqs_to_abort, ) + def _finish_request(self, req_state: RequestState) -> None: + req_id = req_state.request_id + self.request_states.pop(req_id) + + internal_ids = self.external_req_ids[req_state.external_req_id] + internal_ids.remove(req_id) + if not internal_ids: + del self.external_req_ids[req_state.external_req_id] + + # Remove parent request if applicable. 
+ parent_req = req_state.parent_req + if parent_req and not parent_req.child_requests: + self.parent_requests.pop(parent_req.request_id, None) + + if not self.request_states: + self._requests_drained.set() + def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None): self.lora_states.update_scheduler_stats(scheduler_stats) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 5db3a53266f0..4056c225c907 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -4,12 +4,12 @@ import contextlib import os import weakref -from collections.abc import Callable, Iterator +from collections.abc import Callable, Iterator, Mapping from dataclasses import dataclass from enum import Enum, auto from multiprocessing import Process, connection from multiprocessing.process import BaseProcess -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch import msgspec @@ -224,6 +224,14 @@ def get_device_indices( return value +def get_prompt_text(prompt: Any) -> str | None: + if isinstance(prompt, str): + return prompt + if isinstance(prompt, Mapping): + return cast(str | None, prompt.get("prompt")) + return None + + class CoreEngineActorManager: """ Utility class to handle creation, readiness, and shutdown diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 432510ad2b4d..b963fea43df5 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -37,21 +37,19 @@ class StreamingUpdate: with new input data. 
""" - resumable: bool mm_features: list[MultiModalFeatureSpec] | None prompt_token_ids: list[int] | None - prompt_embeds: torch.Tensor | None max_tokens: int arrival_time: float sampling_params: SamplingParams | None @classmethod - def from_request(cls, request: "Request") -> "StreamingUpdate": + def from_request(cls, request: "Request") -> "StreamingUpdate | None": + if not request.resumable: + return None return cls( - resumable=request.resumable, mm_features=request.mm_features, prompt_token_ids=request.prompt_token_ids, - prompt_embeds=request.prompt_embeds, max_tokens=request.max_tokens, arrival_time=request.arrival_time, sampling_params=request.sampling_params, @@ -116,6 +114,9 @@ def __init__( self.prompt_token_ids = prompt_token_ids self.prompt_embeds = prompt_embeds + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds + ) self._output_token_ids: list[int] = [] self._all_token_ids: list[int] = ( self.prompt_token_ids.copy() @@ -166,9 +167,8 @@ def __init__( # Used for streaming self.resumable = resumable - self.streaming_queue: deque[StreamingUpdate] | None = ( - deque() if resumable else None - ) + # None entry in the queue means finished. 
+ self.streaming_queue: deque[StreamingUpdate | None] | None = None @classmethod def from_engine_core_request( @@ -224,12 +224,6 @@ def num_tokens_with_spec(self) -> int: def num_output_tokens(self) -> int: return len(self._output_token_ids) - @property - def num_prompt_tokens(self) -> int: - return length_from_prompt_token_ids_or_embeds( - self.prompt_token_ids, self.prompt_embeds - ) - @property def num_encoder_inputs(self) -> int: return len(self.mm_features) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 8d5ef819e5a6..662badeb5f1a 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -54,15 +54,13 @@ class CachedRequestState: pooling_states: PoolingStates | None = None def __post_init__(self): - if self.pooling_params is not None: - self.pooling_states = PoolingStates() - - @property - def num_prompt_tokens(self) -> int: - return length_from_prompt_token_ids_or_embeds( + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.prompt_token_ids, self.prompt_embeds ) + if self.pooling_params is not None: + self.pooling_states = PoolingStates() + @property def num_tokens(self) -> int: return self.num_prompt_tokens + len(self.output_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 36f3b234fe6c..143a3179fcc1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1120,9 +1120,7 @@ def _update_streaming_request( NOTE: prompt_token_ids includes intermediate output tokens - tokens previously generated but now are input context (part of the prompt). 
""" - if req_id in self.input_batch.req_id_to_index: - self.input_batch.remove_request(req_id) - + self.input_batch.remove_request(req_id) req_state = self.requests[req_id] req_state.prompt_token_ids = new_req_data.prompt_token_ids @@ -1132,6 +1130,9 @@ def _update_streaming_request( req_state.pooling_params = new_req_data.pooling_params req_state.block_ids = new_req_data.block_ids req_state.num_computed_tokens = new_req_data.num_computed_tokens + req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds + ) # Clear `output_token_ids` as previous output tokens are now part of # `prompt_token_ids`.