diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index 68afe784988..1ed21b07bb7 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -1188,7 +1188,7 @@ def test_mcp_path_based_server_segregation(monkeypatch): # Patch the session manager to send a dummy response and capture context async def dummy_handle_request(scope, receive, send): """Dummy handler for testing""" - # Get auth context + # Get auth context (includes client_ip as 7th value) ( user_api_key_auth, mcp_auth_header, @@ -1196,6 +1196,7 @@ async def dummy_handle_request(scope, receive, send): mcp_server_auth_headers, oauth2_headers, raw_headers, + client_ip, ) = get_auth_context() # Capture the MCP servers for testing diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py b/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py index 0e150e064c7..5dbad53948b 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py @@ -11,49 +11,54 @@ class MockGuardrail(CustomGuardrail): """Simple guardrail mock that records invocations.""" - def __init__(self, return_texts=None): + def __init__(self): super().__init__(guardrail_name="mock-mcp-guardrail") - self.return_texts = return_texts self.call_count = 0 self.last_inputs = None + self.last_request_data = None async def apply_guardrail(self, inputs, request_data, input_type, **kwargs): self.call_count += 1 self.last_inputs = inputs - - if self.return_texts is not None: - return {"texts": self.return_texts} - - texts = inputs.get("texts", []) - return {"texts": [f"{text} [SAFE]" for text in texts]} + self.last_request_data = request_data + return None # Guardrail doesn't modify for MCP tools @pytest.mark.asyncio async def test_process_input_messages_updates_content(): - """Handler should update the synthetic message content when guardrail modifies text.""" + """Handler should pass tool definition to guardrail when mcp_tool_name is present.""" handler = MCPGuardrailTranslationHandler() guardrail = MockGuardrail() - original_content = "Tool: weather\nArguments: {'city': 'tokyo'}" data = { - "messages": [{"role": "user", "content": original_content}], "mcp_tool_name": "weather", + "mcp_arguments": {"city": "tokyo"}, + "mcp_tool_description": "Get weather for a city", } result = await handler.process_input_messages(data, guardrail) - assert result["messages"][0]["content"].endswith("[SAFE]") - assert guardrail.last_inputs == {"texts": [original_content]} + # Handler passes data through unchanged + assert result == data + # Guardrail was called assert guardrail.call_count == 1 + # Guardrail received tools (not texts) with tool definition + assert guardrail.last_inputs is not None + tools = guardrail.last_inputs.get("tools", []) + assert len(tools) == 1 + assert tools[0]["function"]["name"] == "weather" + # Request data was passed to guardrail + assert guardrail.last_request_data == data @pytest.mark.asyncio -async def test_process_input_messages_skips_when_no_messages(): - """Handler should skip guardrail invocation if messages array is missing or empty.""" +async def test_process_input_messages_skips_when_no_tool_name(): + """Handler should skip guardrail invocation if mcp_tool_name is missing.""" handler = MCPGuardrailTranslationHandler() guardrail = MockGuardrail() - data = {"mcp_tool_name": "noop"} + # No mcp_tool_name means nothing to process + data = {"some_other_field": "value"} result = await handler.process_input_messages(data, guardrail) assert result == data @@ -61,18 +66,17 @@ async def test_process_input_messages_skips_when_no_messages(): @pytest.mark.asyncio -async def test_process_input_messages_handles_empty_guardrail_result(): - """Handler should leave content untouched when guardrail returns no text updates.""" +async def test_process_input_messages_handles_minimal_data(): + """Handler should work with just mcp_tool_name (minimal required field).""" handler = MCPGuardrailTranslationHandler() - guardrail = MockGuardrail(return_texts=[]) + guardrail = MockGuardrail() - original_content = "Tool: calendar\nArguments: {'date': '2024-12-25'}" - data = { - "messages": [{"role": "user", "content": original_content}], - "mcp_tool_name": "calendar", - } + data = {"mcp_tool_name": "simple_tool"} result = await handler.process_input_messages(data, guardrail) - assert result["messages"][0]["content"] == original_content + assert result == data assert guardrail.call_count == 1 + tools = guardrail.last_inputs.get("tools", []) + assert len(tools) == 1 + assert tools[0]["function"]["name"] == "simple_tool" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py index 4c5723b8284..e4b4d3ba189 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py @@ -3,6 +3,23 @@ from unittest.mock import AsyncMock, MagicMock, patch +# Fixture to mock IP address check for all MCP tests +# This prevents tests from failing due to IP-based access control +@pytest.fixture(autouse=True) +def mock_mcp_client_ip(): + """Mock IPAddressUtils.get_mcp_client_ip to return None for all tests. + + This bypasses IP-based access control in tests, since the MCP server's + available_on_public_internet defaults to False and mock requests don't + have proper client IP context. + """ + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.IPAddressUtils.get_mcp_client_ip", + return_value=None, + ): + yield + + @pytest.mark.asyncio async def test_authorize_endpoint_includes_response_type(): """Test that authorize endpoint includes response_type=code parameter (fixes #15684)""" diff --git a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py index 9588c3b55c3..bfeabb6f7ca 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py @@ -118,11 +118,13 @@ def model_dump(self, mode="json", exclude_none=False): mock_a2a_types.SendStreamingMessageRequest = SendStreamingMessageRequest # Patch at the source modules + # Note: add_litellm_data_to_request is called from common_request_processing, + # so we need to patch it there, not at litellm_pre_call_utils with patch( "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent", return_value=mock_agent, ), patch( - "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request", + "litellm.proxy.common_request_processing.add_litellm_data_to_request", side_effect=mock_add_litellm_data, ) as mock_add_data, patch( "litellm.a2a_protocol.create_a2a_client", diff --git a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py index f7e7fcebaef..83914e30354 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py @@ -1130,6 +1130,8 @@ def test_registry_returns_entries_when_enabled(self): mock_manager = MagicMock() mock_manager.get_registry.return_value = {mock_server.server_id: mock_server} + # The registry endpoint uses get_filtered_registry (filters by client IP) + mock_manager.get_filtered_registry.return_value = {mock_server.server_id: mock_server} with patch_proxy_general_settings({"enable_mcp_registry": True}), patch( "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",