Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,63 +16,107 @@ def __init__(self, return_texts=None):
self.return_texts = return_texts
self.call_count = 0
self.last_inputs = None
self.last_request_data = None

async def apply_guardrail(self, inputs, request_data, input_type, **kwargs):
self.call_count += 1
self.last_inputs = inputs
self.last_request_data = request_data

if self.return_texts is not None:
return {"texts": self.return_texts}

texts = inputs.get("texts", [])
return {"texts": [f"{text} [SAFE]" for text in texts]}
# Return original inputs (no modification for tool-based guardrails)
return inputs


@pytest.mark.asyncio
async def test_process_input_messages_updates_content():
"""Handler should update the synthetic message content when guardrail modifies text."""
async def test_process_input_messages_calls_guardrail_with_tool():
"""Handler should call guardrail with tool definition when mcp_tool_name is present."""
handler = MCPGuardrailTranslationHandler()
guardrail = MockGuardrail()

original_content = "Tool: weather\nArguments: {'city': 'tokyo'}"
data = {
"messages": [{"role": "user", "content": original_content}],
"mcp_tool_name": "weather",
"mcp_arguments": {"city": "tokyo"},
"mcp_tool_description": "Get weather for a city",
}

result = await handler.process_input_messages(data, guardrail)

assert result["messages"][0]["content"].endswith("[SAFE]")
assert guardrail.last_inputs == {"texts": [original_content]}
# Guardrail should be called once
assert guardrail.call_count == 1

# Guardrail should receive tool definition in inputs
assert "tools" in guardrail.last_inputs
assert len(guardrail.last_inputs["tools"]) == 1

# ChatCompletionToolParam is a TypedDict (dict), so dict access works.
# Convert to dict explicitly to ensure compatibility with any future changes.
tool = dict(guardrail.last_inputs["tools"][0])
assert tool.get("type") == "function"

# The function is also a TypedDict (ChatCompletionToolParamFunctionChunk)
function = dict(tool.get("function", {}))
assert function.get("name") == "weather"
assert function.get("description") == "Get weather for a city"

# Request data should be passed through
assert guardrail.last_request_data == data

# Result should be the original data (unchanged)
assert result == data


@pytest.mark.asyncio
async def test_process_input_messages_skips_when_no_messages():
"""Handler should skip guardrail invocation if messages array is missing or empty."""
async def test_process_input_messages_skips_when_no_tool_name():
"""Handler should skip guardrail invocation if mcp_tool_name is missing."""
handler = MCPGuardrailTranslationHandler()
guardrail = MockGuardrail()

data = {"mcp_tool_name": "noop"}
# No mcp_tool_name in data - guardrail should not be called
data = {"some_other_field": "value"}
result = await handler.process_input_messages(data, guardrail)

assert result == data
assert guardrail.call_count == 0


@pytest.mark.asyncio
async def test_process_input_messages_handles_empty_guardrail_result():
"""Handler should leave content untouched when guardrail returns no text updates."""
async def test_process_input_messages_handles_name_alias():
"""Handler should accept 'name' as an alias for 'mcp_tool_name'."""
handler = MCPGuardrailTranslationHandler()
guardrail = MockGuardrail()

data = {
"name": "calendar",
"arguments": {"date": "2024-12-25"},
}

result = await handler.process_input_messages(data, guardrail)

assert guardrail.call_count == 1
# Convert to dict for safe access
tool = dict(guardrail.last_inputs["tools"][0])
function = dict(tool.get("function", {}))
assert function.get("name") == "calendar"


@pytest.mark.asyncio
async def test_process_input_messages_handles_missing_arguments():
"""Handler should handle missing mcp_arguments gracefully."""
handler = MCPGuardrailTranslationHandler()
guardrail = MockGuardrail(return_texts=[])
guardrail = MockGuardrail()

original_content = "Tool: calendar\nArguments: {'date': '2024-12-25'}"
data = {
"messages": [{"role": "user", "content": original_content}],
"mcp_tool_name": "calendar",
"mcp_tool_name": "simple_tool",
# No mcp_arguments provided
}

result = await handler.process_input_messages(data, guardrail)

assert result["messages"][0]["content"] == original_content
assert guardrail.call_count == 1
# Convert to dict for safe access
tool = dict(guardrail.last_inputs["tools"][0])
function = dict(tool.get("function", {}))
assert function.get("name") == "simple_tool"
123 changes: 89 additions & 34 deletions tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,59 @@
"""
Mock tests for A2A endpoints.

Tests that invoke_agent_a2a properly integrates with add_litellm_data_to_request.
Tests that invoke_agent_a2a properly integrates with ProxyBaseLLMRequestProcessing
for adding litellm data to requests.
"""

import sys
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch, call

import pytest


@pytest.mark.asyncio
async def test_invoke_agent_a2a_adds_litellm_data():
"""
Test that invoke_agent_a2a calls add_litellm_data_to_request
Test that invoke_agent_a2a calls common_processing_pre_call_logic
and the resulting data includes proxy_server_request.
"""
from litellm.proxy._types import UserAPIKeyAuth

# Track the data passed to add_litellm_data_to_request
captured_data = {}

async def mock_add_litellm_data(data, **kwargs):
# Simulate what add_litellm_data_to_request does
# Track calls to common_processing_pre_call_logic
processing_call_args = {}
returned_data = {}

async def mock_common_processing(
request,
general_settings,
user_api_key_dict,
proxy_logging_obj,
proxy_config,
route_type,
version,
):
# Capture the actual arguments passed to common_processing_pre_call_logic
processing_call_args["request"] = request
processing_call_args["general_settings"] = general_settings
processing_call_args["user_api_key_dict"] = user_api_key_dict
processing_call_args["route_type"] = route_type
processing_call_args["version"] = version

# Get the data from the processor instance
data = mock_processor_instance.data

# Simulate what common_processing_pre_call_logic does
data["proxy_server_request"] = {
"url": "http://localhost:4000/a2a/test-agent",
"method": "POST",
"headers": {},
"body": dict(data),
}
captured_data.update(data)
return data

# Store the returned data to verify endpoint uses it
returned_data.update(data)
mock_logging_obj = MagicMock()
return data, mock_logging_obj
Comment on lines +26 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong route_type asserted

invoke_agent_a2a() passes route_type="asend_message" into common_processing_pre_call_logic (litellm/proxy/agent_endpoints/a2a_endpoints.py:300-308). This test asserts route_type == "a2a_request", so it will fail (or force an incorrect behavior) even though the endpoint is correct. Update the expectation to match the actual route type used by the endpoint.


# Mock response from asend_message
mock_response = MagicMock()
Expand All @@ -40,12 +63,22 @@ async def mock_add_litellm_data(data, **kwargs):
"result": {"status": "success"},
}

# Track what gets passed to asend_message
asend_message_call_args = {}

async def mock_asend_message(*args, **kwargs):
asend_message_call_args["args"] = args
asend_message_call_args["kwargs"] = kwargs
return mock_response

# Mock agent
mock_agent = MagicMock()
mock_agent.agent_card_params = {
"url": "http://backend-agent:10001",
"name": "Test Agent",
}
mock_agent.litellm_params = {}
mock_agent.agent_id = "test-agent-id"

# Mock request
mock_request = MagicMock()
Expand All @@ -71,32 +104,22 @@ async def mock_add_litellm_data(data, **kwargs):
)

# Try to use real a2a.types if available, otherwise create realistic mocks
# This test focuses on LiteLLM integration, not A2A protocol correctness,
# but we want mocks that behave like the real types to catch usage issues
try:
from a2a.types import (
MessageSendParams,
SendMessageRequest,
SendStreamingMessageRequest,
)

# Real types available - use them
pass
except ImportError:
# Real types not available - create realistic mocks
pass

def make_mock_pydantic_class(name):
"""Create a mock class that behaves like a Pydantic model."""

class MockPydanticClass:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
# Store kwargs for model_dump() if needed
self._kwargs = kwargs

def model_dump(self, mode="json", exclude_none=False):
"""Mock model_dump method."""
result = dict(self._kwargs)
if exclude_none:
result = {k: v for k, v in result.items() if v is not None}
Expand All @@ -117,26 +140,43 @@ def model_dump(self, mode="json", exclude_none=False):
mock_a2a_types.SendMessageRequest = SendMessageRequest
mock_a2a_types.SendStreamingMessageRequest = SendStreamingMessageRequest

# Create mock processor instance to capture data
mock_processor_instance = MagicMock()
mock_processor_instance.common_processing_pre_call_logic = AsyncMock(
side_effect=mock_common_processing
)

def mock_processor_init(data):
mock_processor_instance.data = data
return mock_processor_instance

# Patch at the source modules
with patch(
"litellm.proxy.agent_endpoints.a2a_endpoints._get_agent",
return_value=mock_agent,
), patch(
"litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request",
side_effect=mock_add_litellm_data,
) as mock_add_data, patch(
"litellm.proxy.agent_endpoints.a2a_endpoints.AgentRequestHandler.is_agent_allowed",
new_callable=AsyncMock,
return_value=True,
), patch(
"litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing",
side_effect=mock_processor_init,
) as mock_processor_class, patch(
"litellm.a2a_protocol.create_a2a_client",
new_callable=AsyncMock,
), patch(
"litellm.a2a_protocol.asend_message",
new_callable=AsyncMock,
return_value=mock_response,
side_effect=mock_asend_message,
), patch(
"litellm.proxy.proxy_server.general_settings",
{},
), patch(
"litellm.proxy.proxy_server.proxy_config",
MagicMock(),
), patch(
"litellm.proxy.proxy_server.proxy_logging_obj",
MagicMock(),
), patch(
"litellm.proxy.proxy_server.version",
"1.0.0",
Expand All @@ -158,13 +198,28 @@ def model_dump(self, mode="json", exclude_none=False):
user_api_key_dict=mock_user_api_key_dict,
)

# Verify add_litellm_data_to_request was called
mock_add_data.assert_called_once()

# Verify model and custom_llm_provider were set
assert captured_data.get("model") == "a2a_agent/Test Agent"
assert captured_data.get("custom_llm_provider") == "a2a_agent"

# Verify proxy_server_request was added
assert "proxy_server_request" in captured_data
assert captured_data["proxy_server_request"]["method"] == "POST"
# Verify ProxyBaseLLMRequestProcessing was instantiated with data dict
mock_processor_class.assert_called_once()
init_call_args = mock_processor_class.call_args
assert isinstance(init_call_args[0][0], dict), "Processor should be initialized with a dict"

# Verify common_processing_pre_call_logic was called
mock_processor_instance.common_processing_pre_call_logic.assert_called_once()

# Verify the call included correct route_type and version
assert processing_call_args.get("route_type") == "a2a_request"
assert processing_call_args.get("version") == "1.0.0"

# Verify model and custom_llm_provider were set in the data
assert returned_data.get("model") == "a2a_agent/Test Agent"
assert returned_data.get("custom_llm_provider") == "a2a_agent"

# Verify proxy_server_request was added by common_processing_pre_call_logic
assert "proxy_server_request" in returned_data
assert returned_data["proxy_server_request"]["method"] == "POST"

# Verify the data with proxy_server_request is what gets passed downstream
# (The endpoint should use the returned data from common_processing_pre_call_logic)
assert "metadata" in asend_message_call_args.get("kwargs", {}) or \
any("proxy_server_request" in str(arg) for arg in asend_message_call_args.get("args", [])), \
"Data from common_processing_pre_call_logic should be passed to asend_message"
Loading