
Commit 30e62c8

wuhang2014 authored and eicherseiji committed

[Feature][Responses API] Support MCP tools with streaming mode + background mode (vllm-project#23927)

Signed-off-by: wuhang <[email protected]>

1 parent d84bdb0 commit 30e62c8

File tree

3 files changed: +138 -26 lines changed

tests/entrypoints/openai/test_response_api_with_harmony.py

Lines changed: 18 additions & 1 deletion
@@ -275,7 +275,8 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_streaming(client: OpenAI, model_name: str):
+@pytest.mark.parametrize("background", [True, False])
+async def test_streaming(client: OpenAI, model_name: str, background: bool):
     # TODO: Add back when web search and code interpreter are available in CI
     prompts = [
         "tell me a story about a cat in 20 words",
@@ -300,11 +301,16 @@ async def test_streaming(client: OpenAI, model_name: str):
             # },
         ],
         stream=True,
+        background=background,
    )
 
    events = []
    current_event_mode = None
+    resp_id = None
    async for event in response:
+        if event.type == "response.created":
+            resp_id = event.response.id
+
        if current_event_mode != event.type:
            current_event_mode = event.type
            print(f"\n[{event.type}] ", end="", flush=True)
@@ -322,6 +328,17 @@ async def test_streaming(client: OpenAI, model_name: str):
 
    assert len(events) > 0
 
+    if background:
+        starting_after = 5
+        async with await client.responses.retrieve(
+                response_id=resp_id,
+                stream=True,
+                starting_after=starting_after) as stream:
+            counter = starting_after
+            async for event in stream:
+                counter += 1
+                assert event == events[counter]
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
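
For orientation, a minimal client-side sketch of the flow this test exercises: start a background streaming response, then re-attach to it later with starting_after. The base URL, API key, and model id below are placeholders, not part of this change.

import asyncio

from openai import AsyncOpenAI

# Placeholder endpoint and model id; point these at your own vLLM deployment.
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "<your-model>"


async def main() -> None:
    # Start a background + streaming response (mirrors the test above).
    stream = await client.responses.create(
        model=MODEL,
        input="tell me a story about a cat in 20 words",
        background=True,
        stream=True,
    )

    resp_id = None
    events = []
    async for event in stream:
        if event.type == "response.created":
            resp_id = event.response.id
        events.append(event)

    # Re-attach later, skipping events already seen (indices 0..5 here).
    async with await client.responses.retrieve(
            response_id=resp_id, stream=True, starting_after=5) as resumed:
        async for event in resumed:
            print(event.type)


asyncio.run(main())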

vllm/entrypoints/openai/api_server.py

Lines changed: 14 additions & 2 deletions
@@ -616,21 +616,33 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
 
 
 @router.get("/v1/responses/{response_id}")
-async def retrieve_responses(response_id: str, raw_request: Request):
+async def retrieve_responses(
+    response_id: str,
+    raw_request: Request,
+    starting_after: Optional[int] = None,
+    stream: Optional[bool] = False,
+):
    handler = responses(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Responses API")
 
    try:
-        response = await handler.retrieve_responses(response_id)
+        response = await handler.retrieve_responses(
+            response_id,
+            starting_after=starting_after,
+            stream=stream,
+        )
    except Exception as e:
        raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
                            detail=str(e)) from e
 
    if isinstance(response, ErrorResponse):
        return JSONResponse(content=response.model_dump(),
                            status_code=response.error.code)
+    elif stream:
+        return StreamingResponse(content=response,
+                                 media_type="text/event-stream")
    return JSONResponse(content=response.model_dump())
 
 
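
At the HTTP level, starting_after and stream arrive as query parameters on the existing GET route. A rough sketch with httpx follows; the server address and response id are placeholders.

import httpx

BASE = "http://localhost:8000"      # placeholder server address
RESP_ID = "resp_..."                # id returned by a prior background request

# Default behavior is unchanged: fetch the stored response as JSON.
print(httpx.get(f"{BASE}/v1/responses/{RESP_ID}").json())

# New behavior: re-attach to the SSE event stream of a background request,
# skipping the first six events.
with httpx.stream(
        "GET",
        f"{BASE}/v1/responses/{RESP_ID}",
        params={"stream": "true", "starting_after": 5},
        timeout=None,
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)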

vllm/entrypoints/openai/serving_responses.py

Lines changed: 106 additions & 23 deletions
@@ -4,6 +4,7 @@
 import asyncio
 import json
 import time
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
@@ -55,7 +56,7 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob as SampleLogprob
@@ -168,6 +169,11 @@ def __init__(
        # never remove messages from the store.
        self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}
 
+        # HACK(wuhang): This is a hack. We should use a better store.
+        # FIXME: If enable_store=True, this may cause a memory leak since we
+        # never remove events from the store.
+        self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
+
        self.background_tasks: dict[str, asyncio.Task] = {}
 
        self.tool_server = tool_server
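
The new event_store keeps, per background request id, every serialized event emitted so far plus an asyncio.Event used to wake readers when more arrive; schematically (the key below is illustrative only):

import asyncio
from collections import deque

# request id -> (events emitted so far, "new events available" signal)
event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
event_store["resp_example"] = (deque(), asyncio.Event())
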
@@ -249,15 +255,6 @@ async def create_responses(
        if raw_request:
            raw_request.state.request_metadata = request_metadata
 
-        if self.tool_server is not None and isinstance(
-                self.tool_server,
-                MCPToolServer) and request.stream and request.tools and any(
-                    tool.type in ["web_search_preview", "code_interpreter"]
-                    for tool in request.tools):
-            return self.create_error_response(
-                "MCP tool server is not supported in background mode and "
-                "streaming mode")
-
        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[ConversationContext, None]] = []
 
@@ -329,25 +326,44 @@
                self.response_store[response.id] = response
 
            # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
+            if request.stream:
+                task = asyncio.create_task(
+                    self._run_background_request_stream(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{request.request_id}",
+                )
+            else:
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )
 
            # For cleanup.
            response_id = response.id
            self.background_tasks[response_id] = task
            task.add_done_callback(
                lambda _: self.background_tasks.pop(response_id, None))
+
+            if request.stream:
+                return self.responses_background_stream_generator(
+                    request.request_id)
            return response
 
        if request.stream:
@@ -736,6 +752,40 @@ def _construct_input_messages_with_harmony(
                prev_outputs.append(response_msg)
        return messages
 
+    async def _run_background_request_stream(
+        self,
+        request: ResponsesRequest,
+        *args,
+        **kwargs,
+    ):
+        event_deque: deque[str] = deque()
+        new_event_signal = asyncio.Event()
+        self.event_store[request.request_id] = (event_deque, new_event_signal)
+        response = None
+        try:
+            generator = self.responses_stream_generator(
+                request, *args, **kwargs)
+            async for event in generator:
+                event_deque.append(event)
+                new_event_signal.set()  # Signal new event available
+        except Exception as e:
+            logger.exception("Background request failed for %s",
+                             request.request_id)
+            response = self.create_error_response(str(e))
+        finally:
+            # Mark as finished with a special marker
+            event_deque.append("__STREAM_END__")
+            new_event_signal.set()
+
+        if response is not None and isinstance(response, ErrorResponse):
+            # If the request has failed, update the status to "failed".
+            response_id = request.request_id
+            async with self.response_store_lock:
+                stored_response = self.response_store.get(response_id)
+                assert stored_response is not None
+                if stored_response.status not in ("completed", "cancelled"):
+                    stored_response.status = "failed"
+
    async def _run_background_request(
        self,
        request: ResponsesRequest,
@@ -759,9 +809,36 @@ async def _run_background_request(
            if stored_response.status not in ("completed", "cancelled"):
                stored_response.status = "failed"
 
+    async def responses_background_stream_generator(
+        self,
+        response_id: str,
+        starting_after: Optional[int] = None,
+    ):
+        if response_id not in self.event_store:
+            raise ValueError(f"Unknown response_id: {response_id}")
+
+        event_deque, new_event_signal = self.event_store[response_id]
+        start_index = 0 if starting_after is None else starting_after + 1
+        current_index = start_index
+
+        while True:
+            new_event_signal.clear()
+
+            # Yield existing events from start_index
+            while current_index < len(event_deque):
+                event = event_deque[current_index]
+                if event == "__STREAM_END__":
+                    return
+                yield event
+                current_index += 1
+
+            await new_event_signal.wait()
+
    async def retrieve_responses(
        self,
        response_id: str,
+        starting_after: Optional[int],
+        stream: Optional[bool],
    ) -> Union[ErrorResponse, ResponsesResponse]:
        if not response_id.startswith("resp_"):
            return self._make_invalid_id_error(response_id)
@@ -771,6 +848,12 @@ async def retrieve_responses(
 
        if response is None:
            return self._make_not_found_error(response_id)
+
+        if stream:
+            return self.responses_background_stream_generator(
+                response_id,
+                starting_after,
+            )
        return response
 
    async def cancel_responses(
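
Taken together, _run_background_request_stream is the producer and responses_background_stream_generator is the consumer of a simple replayable stream: events accumulate in a deque, an asyncio.Event wakes waiting readers, and a "__STREAM_END__" sentinel terminates them. A standalone sketch of the same pattern (not vLLM code; names are illustrative):

import asyncio
from collections import deque
from typing import Optional

STREAM_END = "__STREAM_END__"


async def producer(events: deque, signal: asyncio.Event) -> None:
    # Stand-in for the background generation loop: append each event and
    # wake any consumers waiting on the signal.
    for i in range(10):
        await asyncio.sleep(0.01)
        events.append(f"event-{i}")
        signal.set()
    events.append(STREAM_END)
    signal.set()


async def consumer(events: deque, signal: asyncio.Event,
                   starting_after: Optional[int] = None) -> None:
    # Replay from an arbitrary index, then block until new events arrive.
    index = 0 if starting_after is None else starting_after + 1
    while True:
        signal.clear()
        while index < len(events):
            event = events[index]
            if event == STREAM_END:
                return
            print(event)
            index += 1
        await signal.wait()


async def main() -> None:
    events: deque = deque()
    signal = asyncio.Event()
    task = asyncio.create_task(producer(events, signal))
    await asyncio.sleep(0.05)                 # consumer joins late
    await consumer(events, signal, starting_after=5)
    await task


asyncio.run(main())

Because events are never popped, a late reader can replay from any index; this is also why the diff carries a FIXME about unbounded growth of the store.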

0 commit comments