diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 5e171af525..a834a7a13c 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -459,6 +459,14 @@ async def anthropic_proxy_route( region_name=None, ) + custom_headers = {} + if ( + "authorization" not in request.headers + and "x-api-key" not in request.headers + and anthropic_api_key is not None + ): + custom_headers["x-api-key"] = "{}".format(anthropic_api_key) + ## check for streaming is_streaming_request = await is_streaming_request_fn(request) @@ -466,7 +474,7 @@ async def anthropic_proxy_route( endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(updated_url), - custom_headers={"x-api-key": "{}".format(anthropic_api_key)}, + custom_headers=custom_headers, _forward_headers=True, ) # dynamically construct pass-through endpoint based on incoming path received_value = await endpoint_func( diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_anthropic_auth_headers.py b/tests/test_litellm/proxy/pass_through_endpoints/test_anthropic_auth_headers.py new file mode 100644 index 0000000000..9872ed2c9d --- /dev/null +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_anthropic_auth_headers.py @@ -0,0 +1,218 @@ +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path + +from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + anthropic_proxy_route, +) + + +class TestAnthropicAuthHeaders: + """Test authentication header handling in anthropic_proxy_route.""" + + @pytest.fixture + def mock_request(self): + """Create a mock request object.""" + request = MagicMock() + request.method = "POST" + request.headers = {} + return request + + @pytest.fixture + def mock_response(self): + """Create a mock FastAPI response object.""" + return MagicMock() + + @pytest.fixture + def mock_user_api_key_dict(self): + """Create a mock user API key dict.""" + return {"user_id": "test_user"} + + @pytest.mark.asyncio + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_streaming_request_fn") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router") + async def test_client_authorization_header_priority( + self, + mock_router, + mock_streaming, + mock_create_route, + mock_request, + mock_response, + mock_user_api_key_dict, + ): + """Test that client Authorization header takes priority over server key.""" + # Setup + mock_request.headers = {"authorization": "Bearer client-key-123"} + mock_router.get_credentials.return_value = "server-key-456" + mock_streaming.return_value = False + mock_endpoint_func = AsyncMock(return_value="test_response") + mock_create_route.return_value = mock_endpoint_func + + # Act + await anthropic_proxy_route( + endpoint="v1/messages", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert + mock_create_route.assert_called_once() + call_kwargs = mock_create_route.call_args[1] + + assert call_kwargs["custom_headers"] == {} + assert call_kwargs["_forward_headers"] is True + + @pytest.mark.asyncio + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_streaming_request_fn") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router") + async def test_client_x_api_key_header_priority( + self, + mock_router, + mock_streaming, + mock_create_route, + mock_request, + mock_response, + mock_user_api_key_dict, + ): + """Test that client x-api-key header takes priority over server key.""" + # Setup + mock_request.headers = {"x-api-key": "client-x-api-key-123"} + mock_router.get_credentials.return_value = "server-key-456" + mock_streaming.return_value = False + mock_endpoint_func = AsyncMock(return_value="test_response") + mock_create_route.return_value = mock_endpoint_func + + # Act + await anthropic_proxy_route( + endpoint="v1/messages", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert + mock_create_route.assert_called_once() + call_kwargs = mock_create_route.call_args[1] + + assert call_kwargs["custom_headers"] == {} + assert call_kwargs["_forward_headers"] is True + + @pytest.mark.asyncio + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_streaming_request_fn") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router") + async def test_server_api_key_fallback( + self, + mock_router, + mock_streaming, + mock_create_route, + mock_request, + mock_response, + mock_user_api_key_dict, + ): + """Test that server API key is used when no client authentication is provided.""" + # Setup + mock_request.headers = {} # No authentication headers + mock_router.get_credentials.return_value = "server-key-456" + mock_streaming.return_value = False + mock_endpoint_func = AsyncMock(return_value="test_response") + mock_create_route.return_value = mock_endpoint_func + + # Act + await anthropic_proxy_route( + endpoint="v1/messages", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert + mock_create_route.assert_called_once() + call_kwargs = mock_create_route.call_args[1] + + assert call_kwargs["custom_headers"] == {"x-api-key": "server-key-456"} + assert call_kwargs["_forward_headers"] is True + + @pytest.mark.asyncio + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_streaming_request_fn") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router") + async def test_no_authentication_available( + self, + mock_router, + mock_streaming, + mock_create_route, + mock_request, + mock_response, + mock_user_api_key_dict, + ): + """Test that no x-api-key header is added when no authentication is available.""" + # Setup + mock_request.headers = {} # No authentication headers + mock_router.get_credentials.return_value = None # No server key + mock_streaming.return_value = False + mock_endpoint_func = AsyncMock(return_value="test_response") + mock_create_route.return_value = mock_endpoint_func + + # Act + await anthropic_proxy_route( + endpoint="v1/messages", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert + mock_create_route.assert_called_once() + call_kwargs = mock_create_route.call_args[1] + + assert call_kwargs["custom_headers"] == {} + assert call_kwargs["_forward_headers"] is True + + @pytest.mark.asyncio + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_streaming_request_fn") + @patch("litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router") + async def test_both_client_headers_present( + self, + mock_router, + mock_streaming, + mock_create_route, + mock_request, + mock_response, + mock_user_api_key_dict, + ): + """Test that no server key is added when client has both auth headers.""" + # Setup + mock_request.headers = { + "authorization": "Bearer client-auth-key", + "x-api-key": "client-x-api-key" + } + mock_router.get_credentials.return_value = "server-key-456" + mock_streaming.return_value = False + mock_endpoint_func = AsyncMock(return_value="test_response") + mock_create_route.return_value = mock_endpoint_func + + # Act + await anthropic_proxy_route( + endpoint="v1/messages", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Assert + mock_create_route.assert_called_once() + call_kwargs = mock_create_route.call_args[1] + + assert call_kwargs["custom_headers"] == {} + assert call_kwargs["_forward_headers"] is True \ No newline at end of file