diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py
index 72d468db08f6..0d5836fab5a7 100644
--- a/tests/entrypoints/openai/test_response_api_with_harmony.py
+++ b/tests/entrypoints/openai/test_response_api_with_harmony.py
@@ -275,7 +275,8 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_streaming(client: OpenAI, model_name: str):
+@pytest.mark.parametrize("background", [True, False])
+async def test_streaming(client: OpenAI, model_name: str, background: bool):
     # TODO: Add back when web search and code interpreter are available in CI
     prompts = [
         "tell me a story about a cat in 20 words",
@@ -300,11 +301,16 @@ async def test_streaming(client: OpenAI, model_name: str):
             #     },
         ],
         stream=True,
+        background=background,
     )
 
     events = []
     current_event_mode = None
+    resp_id = None
     async for event in response:
+        if event.type == "response.created":
+            resp_id = event.response.id
+
         if current_event_mode != event.type:
             current_event_mode = event.type
             print(f"\n[{event.type}] ", end="", flush=True)
@@ -322,6 +328,17 @@ async def test_streaming(client: OpenAI, model_name: str):
 
     assert len(events) > 0
 
+    if background:
+        starting_after = 5
+        async with await client.responses.retrieve(
+                response_id=resp_id,
+                stream=True,
+                starting_after=starting_after) as stream:
+            counter = starting_after
+            async for event in stream:
+                counter += 1
+                assert event == events[counter]
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 3cebfdf885be..b6667ebf152e 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
 
 
 @router.get("/v1/responses/{response_id}")
-async def retrieve_responses(response_id: str, raw_request: Request):
+async def retrieve_responses(
+    response_id: str,
+    raw_request: Request,
+    starting_after: Optional[int] = None,
+    stream: Optional[bool] = False,
+):
     handler = responses(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
             message="The model does not support Responses API")
 
     try:
-        response = await handler.retrieve_responses(response_id)
+        response = await handler.retrieve_responses(
+            response_id,
+            starting_after=starting_after,
+            stream=stream,
+        )
     except Exception as e:
         raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
                             detail=str(e)) from e
@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
                             status_code=response.error.code)
+    elif stream:
+        return StreamingResponse(content=response,
+                                 media_type="text/event-stream")
 
     return JSONResponse(content=response.model_dump())
 
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 7f11b37e5172..58424c9d9f7b 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -4,6 +4,7 @@
 import asyncio
 import json
 import time
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
@@ -55,7 +56,7 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
@@ -168,6 +169,11 @@ def __init__(
         # never remove messages from the store.
         self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}
 
+        # HACK(wuhang): This is a hack. We should use a better store.
+        # FIXME: If enable_store=True, this may cause a memory leak since we
+        # never remove events from the store.
+        self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
+
         self.background_tasks: dict[str, asyncio.Task] = {}
 
         self.tool_server = tool_server
@@ -249,15 +255,6 @@ async def create_responses(
         if raw_request:
             raw_request.state.request_metadata = request_metadata
 
-        if self.tool_server is not None and isinstance(
-                self.tool_server,
-                MCPToolServer) and request.stream and request.tools and any(
-                    tool.type in ["web_search_preview", "code_interpreter"]
-                    for tool in request.tools):
-            return self.create_error_response(
-                "MCP tool server is not supported in background mode and "
-                "streaming mode")
-
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[ConversationContext, None]] = []
@@ -329,25 +326,44 @@ async def create_responses(
                 self.response_store[response.id] = response
 
             # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
+            if request.stream:
+                task = asyncio.create_task(
+                    self._run_background_request_stream(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{request.request_id}",
+                )
+            else:
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )
 
             # For cleanup.
             response_id = response.id
             self.background_tasks[response_id] = task
             task.add_done_callback(
                 lambda _: self.background_tasks.pop(response_id, None))
+
+            if request.stream:
+                return self.responses_background_stream_generator(
+                    request.request_id)
             return response
 
         if request.stream:
@@ -736,6 +752,40 @@ def _construct_input_messages_with_harmony(
                 prev_outputs.append(response_msg)
         return messages
 
+    async def _run_background_request_stream(
+        self,
+        request: ResponsesRequest,
+        *args,
+        **kwargs,
+    ):
+        event_deque: deque[str] = deque()
+        new_event_signal = asyncio.Event()
+        self.event_store[request.request_id] = (event_deque, new_event_signal)
+        response = None
+        try:
+            generator = self.responses_stream_generator(
+                request, *args, **kwargs)
+            async for event in generator:
+                event_deque.append(event)
+                new_event_signal.set()  # Signal new event available
+        except Exception as e:
+            logger.exception("Background request failed for %s",
+                             request.request_id)
+            response = self.create_error_response(str(e))
+        finally:
+            # Mark as finished with a special marker
+            event_deque.append("__STREAM_END__")
+            new_event_signal.set()
+
+        if response is not None and isinstance(response, ErrorResponse):
+            # If the request has failed, update the status to "failed".
+            response_id = request.request_id
+            async with self.response_store_lock:
+                stored_response = self.response_store.get(response_id)
+                assert stored_response is not None
+                if stored_response.status not in ("completed", "cancelled"):
+                    stored_response.status = "failed"
+
     async def _run_background_request(
         self,
         request: ResponsesRequest,
@@ -759,9 +809,36 @@ async def _run_background_request(
                 if stored_response.status not in ("completed", "cancelled"):
                     stored_response.status = "failed"
 
+    async def responses_background_stream_generator(
+        self,
+        response_id: str,
+        starting_after: Optional[int] = None,
+    ):
+        if response_id not in self.event_store:
+            raise ValueError(f"Unknown response_id: {response_id}")
+
+        event_deque, new_event_signal = self.event_store[response_id]
+        start_index = 0 if starting_after is None else starting_after + 1
+        current_index = start_index
+
+        while True:
+            new_event_signal.clear()
+
+            # Yield existing events from start_index
+            while current_index < len(event_deque):
+                event = event_deque[current_index]
+                if event == "__STREAM_END__":
+                    return
+                yield event
+                current_index += 1
+
+            await new_event_signal.wait()
+
     async def retrieve_responses(
         self,
         response_id: str,
+        starting_after: Optional[int],
+        stream: Optional[bool],
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if not response_id.startswith("resp_"):
             return self._make_invalid_id_error(response_id)
@@ -771,6 +848,12 @@ async def retrieve_responses(
 
         if response is None:
             return self._make_not_found_error(response_id)
+
+        if stream:
+            return self.responses_background_stream_generator(
+                response_id,
+                starting_after,
+            )
         return response
 
     async def cancel_responses(
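
For context, a minimal client-side sketch of the behavior this diff enables, mirroring the updated test: start a response with `background=True` and `stream=True`, then re-attach to it with `client.responses.retrieve(..., stream=True, starting_after=N)`, which now maps to `GET /v1/responses/{response_id}?stream=true&starting_after=N`. The base URL, API key, and model name below are placeholders, and `starting_after=5` is just the value used in the test; this is an illustrative sketch, not part of the diff.

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Placeholder endpoint/model: point these at a running vLLM server.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    model = "openai/gpt-oss-20b"

    # Kick off a background + streaming response and remember its id.
    resp_id = None
    stream = await client.responses.create(
        model=model,
        input="tell me a story about a cat in 20 words",
        stream=True,
        background=True,
    )
    async for event in stream:
        if event.type == "response.created":
            resp_id = event.response.id

    # Re-attach to the same response, skipping the first 5 events
    # (same pattern as the test added above).
    async with await client.responses.retrieve(
            response_id=resp_id, stream=True, starting_after=5) as resumed:
        async for event in resumed:
            print(event.type)


asyncio.run(main())
```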