diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py b/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py index 0e150e064c7..4e2e1fd0f35 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/guardrail_translation/test_mcp_guardrail_handler.py @@ -16,44 +16,66 @@ def __init__(self, return_texts=None): self.return_texts = return_texts self.call_count = 0 self.last_inputs = None + self.last_request_data = None async def apply_guardrail(self, inputs, request_data, input_type, **kwargs): self.call_count += 1 self.last_inputs = inputs + self.last_request_data = request_data if self.return_texts is not None: return {"texts": self.return_texts} - texts = inputs.get("texts", []) - return {"texts": [f"{text} [SAFE]" for text in texts]} + # Return original inputs (no modification for tool-based guardrails) + return inputs @pytest.mark.asyncio -async def test_process_input_messages_updates_content(): - """Handler should update the synthetic message content when guardrail modifies text.""" +async def test_process_input_messages_calls_guardrail_with_tool(): + """Handler should call guardrail with tool definition when mcp_tool_name is present.""" handler = MCPGuardrailTranslationHandler() guardrail = MockGuardrail() - original_content = "Tool: weather\nArguments: {'city': 'tokyo'}" data = { - "messages": [{"role": "user", "content": original_content}], "mcp_tool_name": "weather", + "mcp_arguments": {"city": "tokyo"}, + "mcp_tool_description": "Get weather for a city", } result = await handler.process_input_messages(data, guardrail) - assert result["messages"][0]["content"].endswith("[SAFE]") - assert guardrail.last_inputs == {"texts": [original_content]} + # Guardrail should be called once assert guardrail.call_count == 1 + # Guardrail should receive tool definition in inputs + assert "tools" in guardrail.last_inputs + assert len(guardrail.last_inputs["tools"]) == 1 + + # ChatCompletionToolParam is a TypedDict (dict), so dict access works. + # Convert to dict explicitly to ensure compatibility with any future changes. + tool = dict(guardrail.last_inputs["tools"][0]) + assert tool.get("type") == "function" + + # The function is also a TypedDict (ChatCompletionToolParamFunctionChunk) + function = dict(tool.get("function", {})) + assert function.get("name") == "weather" + assert function.get("description") == "Get weather for a city" + + # Request data should be passed through + assert guardrail.last_request_data == data + + # Result should be the original data (unchanged) + assert result == data + @pytest.mark.asyncio -async def test_process_input_messages_skips_when_no_messages(): - """Handler should skip guardrail invocation if messages array is missing or empty.""" +async def test_process_input_messages_skips_when_no_tool_name(): + """Handler should skip guardrail invocation if mcp_tool_name is missing.""" handler = MCPGuardrailTranslationHandler() guardrail = MockGuardrail() - data = {"mcp_tool_name": "noop"} + # No mcp_tool_name in data - guardrail should not be called + data = {"some_other_field": "value"} result = await handler.process_input_messages(data, guardrail) assert result == data @@ -61,18 +83,40 @@ async def test_process_input_messages_skips_when_no_messages(): @pytest.mark.asyncio -async def test_process_input_messages_handles_empty_guardrail_result(): - """Handler should leave content untouched when guardrail returns no text updates.""" +async def test_process_input_messages_handles_name_alias(): + """Handler should accept 'name' as an alias for 'mcp_tool_name'.""" + handler = MCPGuardrailTranslationHandler() + guardrail = MockGuardrail() + + data = { + "name": "calendar", + "arguments": {"date": "2024-12-25"}, + } + + result = await handler.process_input_messages(data, guardrail) + + assert guardrail.call_count == 1 + # Convert to dict for safe access + tool = dict(guardrail.last_inputs["tools"][0]) + function = dict(tool.get("function", {})) + assert function.get("name") == "calendar" + + +@pytest.mark.asyncio +async def test_process_input_messages_handles_missing_arguments(): + """Handler should handle missing mcp_arguments gracefully.""" handler = MCPGuardrailTranslationHandler() - guardrail = MockGuardrail(return_texts=[]) + guardrail = MockGuardrail() - original_content = "Tool: calendar\nArguments: {'date': '2024-12-25'}" data = { - "messages": [{"role": "user", "content": original_content}], - "mcp_tool_name": "calendar", + "mcp_tool_name": "simple_tool", + # No mcp_arguments provided } result = await handler.process_input_messages(data, guardrail) - assert result["messages"][0]["content"] == original_content assert guardrail.call_count == 1 + # Convert to dict for safe access + tool = dict(guardrail.last_inputs["tools"][0]) + function = dict(tool.get("function", {})) + assert function.get("name") == "simple_tool" diff --git a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py index 9588c3b55c3..b6f8be47220 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py @@ -1,11 +1,12 @@ """ Mock tests for A2A endpoints. -Tests that invoke_agent_a2a properly integrates with add_litellm_data_to_request. +Tests that invoke_agent_a2a properly integrates with ProxyBaseLLMRequestProcessing +for adding litellm data to requests. """ import sys -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch, call import pytest @@ -13,24 +14,46 @@ @pytest.mark.asyncio async def test_invoke_agent_a2a_adds_litellm_data(): """ - Test that invoke_agent_a2a calls add_litellm_data_to_request + Test that invoke_agent_a2a calls common_processing_pre_call_logic and the resulting data includes proxy_server_request. """ from litellm.proxy._types import UserAPIKeyAuth - # Track the data passed to add_litellm_data_to_request - captured_data = {} - - async def mock_add_litellm_data(data, **kwargs): - # Simulate what add_litellm_data_to_request does + # Track calls to common_processing_pre_call_logic + processing_call_args = {} + returned_data = {} + + async def mock_common_processing( + request, + general_settings, + user_api_key_dict, + proxy_logging_obj, + proxy_config, + route_type, + version, + ): + # Capture the actual arguments passed to common_processing_pre_call_logic + processing_call_args["request"] = request + processing_call_args["general_settings"] = general_settings + processing_call_args["user_api_key_dict"] = user_api_key_dict + processing_call_args["route_type"] = route_type + processing_call_args["version"] = version + + # Get the data from the processor instance + data = mock_processor_instance.data + + # Simulate what common_processing_pre_call_logic does data["proxy_server_request"] = { "url": "http://localhost:4000/a2a/test-agent", "method": "POST", "headers": {}, "body": dict(data), } - captured_data.update(data) - return data + + # Store the returned data to verify endpoint uses it + returned_data.update(data) + mock_logging_obj = MagicMock() + return data, mock_logging_obj # Mock response from asend_message mock_response = MagicMock() @@ -40,12 +63,22 @@ async def mock_add_litellm_data(data, **kwargs): "result": {"status": "success"}, } + # Track what gets passed to asend_message + asend_message_call_args = {} + + async def mock_asend_message(*args, **kwargs): + asend_message_call_args["args"] = args + asend_message_call_args["kwargs"] = kwargs + return mock_response + # Mock agent mock_agent = MagicMock() mock_agent.agent_card_params = { "url": "http://backend-agent:10001", "name": "Test Agent", } + mock_agent.litellm_params = {} + mock_agent.agent_id = "test-agent-id" # Mock request mock_request = MagicMock() @@ -71,32 +104,22 @@ async def mock_add_litellm_data(data, **kwargs): ) # Try to use real a2a.types if available, otherwise create realistic mocks - # This test focuses on LiteLLM integration, not A2A protocol correctness, - # but we want mocks that behave like the real types to catch usage issues try: from a2a.types import ( MessageSendParams, SendMessageRequest, SendStreamingMessageRequest, ) - - # Real types available - use them - pass except ImportError: - # Real types not available - create realistic mocks - pass - def make_mock_pydantic_class(name): """Create a mock class that behaves like a Pydantic model.""" class MockPydanticClass: def __init__(self, **kwargs): self.__dict__.update(kwargs) - # Store kwargs for model_dump() if needed self._kwargs = kwargs def model_dump(self, mode="json", exclude_none=False): - """Mock model_dump method.""" result = dict(self._kwargs) if exclude_none: result = {k: v for k, v in result.items() if v is not None} @@ -117,26 +140,43 @@ def model_dump(self, mode="json", exclude_none=False): mock_a2a_types.SendMessageRequest = SendMessageRequest mock_a2a_types.SendStreamingMessageRequest = SendStreamingMessageRequest + # Create mock processor instance to capture data + mock_processor_instance = MagicMock() + mock_processor_instance.common_processing_pre_call_logic = AsyncMock( + side_effect=mock_common_processing + ) + + def mock_processor_init(data): + mock_processor_instance.data = data + return mock_processor_instance + # Patch at the source modules with patch( "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent", return_value=mock_agent, ), patch( - "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request", - side_effect=mock_add_litellm_data, - ) as mock_add_data, patch( + "litellm.proxy.agent_endpoints.a2a_endpoints.AgentRequestHandler.is_agent_allowed", + new_callable=AsyncMock, + return_value=True, + ), patch( + "litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing", + side_effect=mock_processor_init, + ) as mock_processor_class, patch( "litellm.a2a_protocol.create_a2a_client", new_callable=AsyncMock, ), patch( "litellm.a2a_protocol.asend_message", new_callable=AsyncMock, - return_value=mock_response, + side_effect=mock_asend_message, ), patch( "litellm.proxy.proxy_server.general_settings", {}, ), patch( "litellm.proxy.proxy_server.proxy_config", MagicMock(), + ), patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + MagicMock(), ), patch( "litellm.proxy.proxy_server.version", "1.0.0", @@ -158,13 +198,28 @@ def model_dump(self, mode="json", exclude_none=False): user_api_key_dict=mock_user_api_key_dict, ) - # Verify add_litellm_data_to_request was called - mock_add_data.assert_called_once() - - # Verify model and custom_llm_provider were set - assert captured_data.get("model") == "a2a_agent/Test Agent" - assert captured_data.get("custom_llm_provider") == "a2a_agent" - - # Verify proxy_server_request was added - assert "proxy_server_request" in captured_data - assert captured_data["proxy_server_request"]["method"] == "POST" + # Verify ProxyBaseLLMRequestProcessing was instantiated with data dict + mock_processor_class.assert_called_once() + init_call_args = mock_processor_class.call_args + assert isinstance(init_call_args[0][0], dict), "Processor should be initialized with a dict" + + # Verify common_processing_pre_call_logic was called + mock_processor_instance.common_processing_pre_call_logic.assert_called_once() + + # Verify the call included correct route_type and version + assert processing_call_args.get("route_type") == "a2a_request" + assert processing_call_args.get("version") == "1.0.0" + + # Verify model and custom_llm_provider were set in the data + assert returned_data.get("model") == "a2a_agent/Test Agent" + assert returned_data.get("custom_llm_provider") == "a2a_agent" + + # Verify proxy_server_request was added by common_processing_pre_call_logic + assert "proxy_server_request" in returned_data + assert returned_data["proxy_server_request"]["method"] == "POST" + + # Verify the data with proxy_server_request is what gets passed downstream + # (The endpoint should use the returned data from common_processing_pre_call_logic) + assert "metadata" in asend_message_call_args.get("kwargs", {}) or \ + any("proxy_server_request" in str(arg) for arg in asend_message_call_args.get("args", [])), \ + "Data from common_processing_pre_call_logic should be passed to asend_message"