diff --git a/litellm/main.py b/litellm/main.py index c1c4efd943..969cf55a3d 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -28,6 +28,7 @@ Callable, Coroutine, Dict, + Iterable, List, Literal, Mapping, @@ -1094,23 +1095,68 @@ def completion( # type: ignore # noqa: PLR0915 # validate tool_choice tool_choice = validate_chat_completion_tool_choice(tool_choice=tool_choice) + ######### unpacking kwargs ##################### + args = locals() + skip_mcp_handler = kwargs.pop("_skip_mcp_handler", False) if not skip_mcp_handler and tools: from litellm.responses.mcp.chat_completions_handler import ( - handle_chat_completion_with_mcp, + acompletion_with_mcp, ) - - mcp_handler_context = locals().copy() - completion_callable = globals().get("acompletion") - mcp_result = run_async_function( - handle_chat_completion_with_mcp, - mcp_handler_context, - completion_callable, + from litellm.responses.mcp.litellm_proxy_mcp_handler import ( + LiteLLM_Proxy_MCP_Handler, ) - if mcp_result is not None: - return mcp_result - ######### unpacking kwargs ##################### - args = locals() + from litellm.types.llms.openai import ToolParam + + # Check if MCP tools are present (following responses pattern) + # Cast tools to Optional[Iterable[ToolParam]] for type checking + tools_for_mcp = cast(Optional[Iterable[ToolParam]], tools) + if LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway(tools=tools_for_mcp): + # Return coroutine - acompletion will await it + # completion() can return a coroutine when MCP tools are present, which acompletion() awaits + return acompletion_with_mcp( # type: ignore[return-value] + model=model, + messages=messages, + functions=functions, + function_call=function_call, + timeout=timeout, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stream_options=stream_options, + stop=stop, + max_tokens=max_tokens, + max_completion_tokens=max_completion_tokens, + modalities=modalities, + prediction=prediction, + audio=audio, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + response_format=response_format, + seed=seed, + tools=tools, + tool_choice=tool_choice, + parallel_tool_calls=parallel_tool_calls, + logprobs=logprobs, + top_logprobs=top_logprobs, + deployment_id=deployment_id, + reasoning_effort=reasoning_effort, + verbosity=verbosity, + safety_identifier=safety_identifier, + service_tier=service_tier, + base_url=base_url, + api_version=api_version, + api_key=api_key, + model_list=model_list, + extra_headers=extra_headers, + thinking=thinking, + web_search_options=web_search_options, + shared_session=shared_session, + **kwargs, + ) api_base = kwargs.get("api_base", None) mock_response: Optional[MOCK_RESPONSE_TYPE] = kwargs.get("mock_response", None) mock_tool_calls = kwargs.get("mock_tool_calls", None) diff --git a/litellm/responses/mcp/chat_completions_handler.py b/litellm/responses/mcp/chat_completions_handler.py index 1957e5fa92..6ce59e3e67 100644 --- a/litellm/responses/mcp/chat_completions_handler.py +++ b/litellm/responses/mcp/chat_completions_handler.py @@ -2,127 +2,67 @@ from typing import ( Any, - Awaitable, - Callable, - Dict, - Iterable, + List, Optional, Union, - cast, ) from litellm.responses.mcp.litellm_proxy_mcp_handler import ( LiteLLM_Proxy_MCP_Handler, ) from litellm.responses.utils import ResponsesAPIRequestUtils -from litellm.types.llms.openai import ToolParam from litellm.types.utils import ModelResponse from litellm.utils import CustomStreamWrapper -CompletionCallable = Callable[..., Awaitable[Union[ModelResponse, CustomStreamWrapper]]] - -_CHAT_COMPLETION_CALL_ARG_KEYS = [ - "model", - "messages", - "functions", - "function_call", - "timeout", - "temperature", - "top_p", - "n", - "stream", - "stream_options", - "stop", - "max_tokens", - "max_completion_tokens", - "modalities", - "prediction", - "audio", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "parallel_tool_calls", - "logprobs", - "top_logprobs", - "deployment_id", - "reasoning_effort", - "verbosity", - "safety_identifier", - "service_tier", - "base_url", - "api_version", - "api_key", - "model_list", - "extra_headers", - "thinking", - "web_search_options", - "shared_session", -] - - -def _build_call_args_from_context(call_context: Dict[str, Any]) -> Dict[str, Any]: - """Build kwargs for `acompletion` from the `completion` call context.""" - - call_args = { - key: call_context.get(key) - for key in _CHAT_COMPLETION_CALL_ARG_KEYS - if key in call_context - } - additional_kwargs = dict(call_context.get("kwargs") or {}) - call_args.update(additional_kwargs) - return call_args - -async def _call_acompletion_internal( - completion_callable: CompletionCallable, **call_args: Any +async def acompletion_with_mcp( + model: str, + messages: List, + tools: Optional[List] = None, + **kwargs: Any, ) -> Union[ModelResponse, CustomStreamWrapper]: - """Invoke `acompletion` while skipping MCP interception to avoid recursion.""" - - safe_args = dict(call_args) - safe_args["_skip_mcp_handler"] = True - safe_args.pop("acompletion", None) - return await completion_callable(**safe_args) - - -async def handle_chat_completion_with_mcp( - call_context: Dict[str, Any], - completion_callable: CompletionCallable, -) -> Optional[Union[ModelResponse, CustomStreamWrapper]]: - """Handle MCP-enabled tool execution for chat completion requests.""" - - call_args = _build_call_args_from_context(call_context) - - tools = call_args.get("tools") - if not tools: - return None - - tools_for_mcp = cast(Optional[Iterable[ToolParam]], tools) - - if not LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway( - tools=tools_for_mcp - ): - return None - - mcp_tools, _ = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools) - if not mcp_tools: - return None - - base_call_args = dict(call_args) - - user_api_key_auth = call_args.get("user_api_key_auth") or ( - (call_args.get("metadata", {}) or {}).get("user_api_key_auth") + """ + Async completion with MCP integration. + + This function handles MCP tool integration following the same pattern as aresponses_api_with_mcp. + It's designed to be called from the synchronous completion() function and return a coroutine. + + When MCP tools with server_url="litellm_proxy" are provided, this function will: + 1. Get available tools from the MCP server manager + 2. Transform them to OpenAI format + 3. Call acompletion with the transformed tools + 4. If require_approval="never" and tool calls are returned, automatically execute them + 5. Make a follow-up call with the tool results + """ + from litellm import acompletion as litellm_acompletion + + # Parse MCP tools and separate from other tools + ( + mcp_tools_with_litellm_proxy, + other_tools, + ) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools) + + if not mcp_tools_with_litellm_proxy: + # No MCP tools, proceed with regular completion + return await litellm_acompletion( + model=model, + messages=messages, + tools=tools, + **kwargs, + ) + + # Extract user_api_key_auth from metadata or kwargs + user_api_key_auth = kwargs.get("user_api_key_auth") or ( + (kwargs.get("metadata", {}) or {}).get("user_api_key_auth") ) + + # Process MCP tools ( deduplicated_mcp_tools, tool_server_map, ) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform( user_api_key_auth=user_api_key_auth, - mcp_tools_with_litellm_proxy=mcp_tools, + mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy, ) openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai( @@ -130,25 +70,43 @@ async def handle_chat_completion_with_mcp( target_format="chat", ) - base_call_args["tools"] = openai_tools or None + # Combine with other tools + all_tools = openai_tools + other_tools if (openai_tools or other_tools) else None + # Determine if we should auto-execute tools should_auto_execute = LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools( - mcp_tools_with_litellm_proxy=mcp_tools + mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy ) + # Extract MCP auth headers ( mcp_auth_header, mcp_server_auth_headers, oauth2_headers, raw_headers, ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request( - secret_fields=base_call_args.get("secret_fields"), + secret_fields=kwargs.get("secret_fields"), tools=tools, ) + # Prepare call parameters + # Remove keys that shouldn't be passed to acompletion + clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]} + + base_call_args = { + "model": model, + "messages": messages, + "tools": all_tools, + "_skip_mcp_handler": True, # Prevent recursion + **clean_kwargs, + } + + # If not auto-executing, just make the call with transformed tools if not should_auto_execute: - return await _call_acompletion_internal(completion_callable, **base_call_args) + return await litellm_acompletion(**base_call_args) + # For auto-execute: disable streaming for initial call + stream = kwargs.get("stream", False) mock_tool_calls = base_call_args.pop("mock_tool_calls", None) initial_call_args = dict(base_call_args) @@ -156,23 +114,26 @@ async def handle_chat_completion_with_mcp( if mock_tool_calls is not None: initial_call_args["mock_tool_calls"] = mock_tool_calls - initial_response = await _call_acompletion_internal( - completion_callable, **initial_call_args - ) + # Make initial call + initial_response = await litellm_acompletion(**initial_call_args) + if not isinstance(initial_response, ModelResponse): return initial_response + # Extract tool calls from response tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response( response=initial_response ) if not tool_calls: - if base_call_args.get("stream"): + # No tool calls, return response or retry with streaming if needed + if stream: retry_args = dict(base_call_args) - retry_args["stream"] = call_args.get("stream") - return await _call_acompletion_internal(completion_callable, **retry_args) + retry_args["stream"] = stream + return await litellm_acompletion(**retry_args) return initial_response + # Execute tool calls tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls( tool_server_map=tool_server_map, tool_calls=tool_calls, @@ -186,14 +147,16 @@ async def handle_chat_completion_with_mcp( if not tool_results: return initial_response + # Create follow-up messages with tool results follow_up_messages = LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat( - original_messages=call_args.get("messages", []), + original_messages=messages, response=initial_response, tool_results=tool_results, ) + # Make follow-up call with original stream setting follow_up_call_args = dict(base_call_args) follow_up_call_args["messages"] = follow_up_messages - follow_up_call_args["stream"] = call_args.get("stream") + follow_up_call_args["stream"] = stream - return await _call_acompletion_internal(completion_callable, **follow_up_call_args) + return await litellm_acompletion(**follow_up_call_args) diff --git a/tests/mcp_tests/test_mcp_chat_completions.py b/tests/mcp_tests/test_mcp_chat_completions.py index ae13b6ca6e..973301abfb 100644 --- a/tests/mcp_tests/test_mcp_chat_completions.py +++ b/tests/mcp_tests/test_mcp_chat_completions.py @@ -141,3 +141,174 @@ async def fake_execute(**kwargs): assert isinstance(response, ModelResponse) tool_calls = response.choices[0].message.tool_calls assert tool_calls is not None and len(tool_calls) == 1 + + +@pytest.mark.asyncio +async def test_completion_mcp_with_streaming_no_timeout_error(monkeypatch): + """ + Test that litellm.completion with stream=True and MCP tools does not raise + RuntimeError: Timeout context manager should be used inside a task. + + This test ensures that the fix in ba43f742ab86d51b7da63077b85b39d0ac808d30 + prevents event loop nesting issues when using MCP tools with streaming. + + The fix changes completion() to return a coroutine from acompletion_with_mcp, + which acompletion() then awaits, avoiding event loop nesting. + """ + from types import SimpleNamespace + from unittest.mock import patch + + from litellm.responses.mcp.litellm_proxy_mcp_handler import ( + LiteLLM_Proxy_MCP_Handler, + ) + from litellm.responses.utils import ResponsesAPIRequestUtils + from litellm.utils import CustomStreamWrapper + + dummy_tool = SimpleNamespace( + name="local_search", + description="search", + inputSchema={"type": "object", "properties": {}}, + ) + + async def fake_process(user_api_key_auth, mcp_tools_with_litellm_proxy): + return [dummy_tool], {"local_search": "local"} + + async def fake_execute(**kwargs): + fake_execute.called = True # type: ignore[attr-defined] + tool_calls = kwargs.get("tool_calls") or [] + assert tool_calls, "tool calls should be present during auto execution" + call_entry = tool_calls[0] + call_id = call_entry.get("id") or call_entry.get("call_id") or "call" + return [ + { + "tool_call_id": call_id, + "result": "executed", + "name": call_entry.get("name", "local_search"), + } + ] + + fake_execute.called = False # type: ignore[attr-defined] + + monkeypatch.setattr( + LiteLLM_Proxy_MCP_Handler, + "_process_mcp_tools_without_openai_transform", + fake_process, + ) + monkeypatch.setattr( + LiteLLM_Proxy_MCP_Handler, + "_execute_tool_calls", + fake_execute, + ) + monkeypatch.setattr( + ResponsesAPIRequestUtils, + "extract_mcp_headers_from_request", + staticmethod(lambda secret_fields, tools: (None, None, None, None)), + ) + + # Create a mock streaming response + class MockStreamingResponse(CustomStreamWrapper): + def __init__(self): + self.chunks = [ + type('Chunk', (), { + 'choices': [type('Choice', (), { + 'delta': type('Delta', (), { + 'content': 'Final' + })() + })()] + })(), + type('Chunk', (), { + 'choices': [type('Choice', (), { + 'delta': type('Delta', (), { + 'content': ' answer' + })() + })()] + })(), + ] + self._index = 0 + + def __iter__(self): + return self + + def __next__(self): + if self._index < len(self.chunks): + chunk = self.chunks[self._index] + self._index += 1 + return chunk + raise StopIteration + + # Track calls to acompletion + acompletion_calls = [] + + async def mock_acompletion(**kwargs): + acompletion_calls.append(kwargs) + # First call (non-streaming for tool extraction) + if not kwargs.get("stream", False): + # Return a ModelResponse with tool_calls using dict format + return ModelResponse( + id="test-1", + model="gpt-4o-mini", + choices=[{ + "message": { + "role": "assistant", + "tool_calls": [{ + "id": "call-1", + "type": "function", + "function": { + "name": "local_search", + "arguments": "{}" + } + }] + }, + "finish_reason": "tool_calls" + }], + created=0, + object="chat.completion", + ) + # Second call (streaming follow-up) + return MockStreamingResponse() + + with patch("litellm.acompletion", side_effect=mock_acompletion): + # This should not raise RuntimeError: Timeout context manager should be used inside a task + # completion() returns a coroutine when MCP tools are present, which acompletion() awaits + response = litellm.completion( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hello"}], + tools=[ + { + "type": "mcp", + "server_url": "litellm_proxy/mcp/local", + "server_label": "local", + "require_approval": "never", + } + ], + stream=True, + mock_response="Final answer", + mock_tool_calls=[ + { + "id": "call-1", + "type": "function", + "function": {"name": "local_search", "arguments": "{}"}, + } + ], + ) + + # completion() returns a coroutine when MCP tools are present + import asyncio + assert asyncio.iscoroutine(response), "completion() should return a coroutine when MCP tools are present" + + # Await the coroutine (this is what acompletion() does internally) + # This should not raise RuntimeError: Timeout context manager should be used inside a task + result = await response + + # Verify response is a streaming response + assert isinstance(result, CustomStreamWrapper) or hasattr(result, '__iter__') + + # Consume the stream to ensure it works + chunks = list(result) + assert len(chunks) > 0, "Should have received streaming chunks" + + # Verify tool execution was called + assert fake_execute.called is True # type: ignore[attr-defined] + + # Verify acompletion was called (should be called by acompletion_with_mcp) + assert len(acompletion_calls) >= 1, "acompletion should be called" diff --git a/tests/test_litellm/responses/mcp/test_chat_completions_handler.py b/tests/test_litellm/responses/mcp/test_chat_completions_handler.py index 96e7c39aee..03a749a808 100644 --- a/tests/test_litellm/responses/mcp/test_chat_completions_handler.py +++ b/tests/test_litellm/responses/mcp/test_chat_completions_handler.py @@ -1,10 +1,10 @@ import pytest -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch from litellm.types.utils import ModelResponse from litellm.responses.mcp.chat_completions_handler import ( - handle_chat_completion_with_mcp, + acompletion_with_mcp, ) from litellm.responses.mcp.litellm_proxy_mcp_handler import ( LiteLLM_Proxy_MCP_Handler, @@ -13,19 +13,24 @@ @pytest.mark.asyncio -async def test_handle_chat_completion_returns_none_without_tools(): - completion_callable = AsyncMock() +async def test_acompletion_with_mcp_returns_normal_completion_without_tools(monkeypatch): + mock_acompletion = AsyncMock(return_value="normal_response") - result = await handle_chat_completion_with_mcp({}, completion_callable) + with patch("litellm.acompletion", mock_acompletion): + result = await acompletion_with_mcp( + model="test-model", + messages=[], + tools=None, + ) - assert result is None - completion_callable.assert_not_awaited() + assert result == "normal_response" + mock_acompletion.assert_awaited_once() @pytest.mark.asyncio -async def test_handle_chat_completion_without_auto_execution_calls_model(monkeypatch): +async def test_acompletion_with_mcp_without_auto_execution_calls_model(monkeypatch): tools = [{"type": "function", "function": {"name": "tool"}}] - completion_callable = AsyncMock(return_value="ok") + mock_acompletion = AsyncMock(return_value="ok") monkeypatch.setattr( LiteLLM_Proxy_MCP_Handler, @@ -35,7 +40,7 @@ async def test_handle_chat_completion_without_auto_execution_calls_model(monkeyp monkeypatch.setattr( LiteLLM_Proxy_MCP_Handler, "_parse_mcp_tools", - staticmethod(lambda tools: (tools, {})), + staticmethod(lambda tools: (tools, [])), ) async def mock_process(**_): return ([], {}) @@ -67,23 +72,25 @@ def mock_extract(**kwargs): staticmethod(mock_extract), ) - call_context = { - "tools": tools, - "messages": [], - "kwargs": {"secret_fields": {"api_key": "value"}}, - } - result = await handle_chat_completion_with_mcp(call_context, completion_callable) + with patch("litellm.acompletion", mock_acompletion): + result = await acompletion_with_mcp( + model="test-model", + messages=[], + tools=tools, + secret_fields={"api_key": "value"}, + ) assert result == "ok" - completion_callable.assert_awaited_once() - kwargs = completion_callable.await_args.kwargs + mock_acompletion.assert_awaited_once() + assert mock_acompletion.await_args is not None + kwargs = mock_acompletion.await_args.kwargs assert kwargs.get("_skip_mcp_handler") is True assert kwargs.get("tools") == ["openai-tool"] assert captured_secret_fields["value"] == {"api_key": "value"} @pytest.mark.asyncio -async def test_handle_chat_completion_auto_exec_performs_follow_up(monkeypatch): +async def test_acompletion_with_mcp_auto_exec_performs_follow_up(monkeypatch): tools = [{"type": "function", "function": {"name": "tool"}}] initial_response = ModelResponse( id="1", @@ -99,7 +106,7 @@ async def test_handle_chat_completion_auto_exec_performs_follow_up(monkeypatch): created=0, object="chat.completion", ) - completion_callable = AsyncMock( + mock_acompletion = AsyncMock( side_effect=[initial_response, follow_up_response] ) @@ -111,7 +118,7 @@ async def test_handle_chat_completion_auto_exec_performs_follow_up(monkeypatch): monkeypatch.setattr( LiteLLM_Proxy_MCP_Handler, "_parse_mcp_tools", - staticmethod(lambda tools: (tools, {"tool": "server"})), + staticmethod(lambda tools: (tools, [])), ) async def mock_process(**_): return (tools, {"tool": "server"}) @@ -155,13 +162,18 @@ async def mock_execute(**_): staticmethod(lambda **_: (None, None, None, None)), ) - call_context = {"tools": tools, "messages": ["msg"], "stream": True} - result = await handle_chat_completion_with_mcp(call_context, completion_callable) + with patch("litellm.acompletion", mock_acompletion): + result = await acompletion_with_mcp( + model="test-model", + messages=["msg"], + tools=tools, + stream=True, + ) assert result is follow_up_response - assert completion_callable.await_count == 2 - first_call = completion_callable.await_args_list[0].kwargs - second_call = completion_callable.await_args_list[1].kwargs + assert mock_acompletion.await_count == 2 + first_call = mock_acompletion.await_args_list[0].kwargs + second_call = mock_acompletion.await_args_list[1].kwargs assert first_call["stream"] is False assert second_call["messages"] == ["follow-up"] assert second_call["stream"] is True