diff --git a/docs/my-website/docs/proxy/guardrails/onyx_security.md b/docs/my-website/docs/proxy/guardrails/onyx_security.md index 85b0ba9f830..d240902eb52 100644 --- a/docs/my-website/docs/proxy/guardrails/onyx_security.md +++ b/docs/my-website/docs/proxy/guardrails/onyx_security.md @@ -128,6 +128,7 @@ guardrails: mode: ["pre_call", "post_call", "during_call"] # Run at multiple stages api_key: os.environ/ONYX_API_KEY api_base: os.environ/ONYX_API_BASE + timeout: 10.0 # Optional, defaults to 10 seconds ``` ### Required Parameters @@ -137,6 +138,7 @@ guardrails: ### Optional Parameters - **`api_base`**: Onyx API base URL (defaults to `https://ai-guard.onyx.security`) +- **`timeout`**: Request timeout in seconds (defaults to `10.0`) ## Environment Variables @@ -145,4 +147,5 @@ You can set these environment variables instead of hardcoding values in your con ```shell export ONYX_API_KEY="your-api-key-here" export ONYX_API_BASE="https://ai-guard.onyx.security" # Optional +export ONYX_TIMEOUT=10 # Optional, timeout in seconds ``` diff --git a/litellm/proxy/guardrails/guardrail_hooks/onyx/onyx.py b/litellm/proxy/guardrails/guardrail_hooks/onyx/onyx.py index 5f57cab1db4..3598dbe741e 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/onyx/onyx.py +++ b/litellm/proxy/guardrails/guardrail_hooks/onyx/onyx.py @@ -8,6 +8,7 @@ import uuid from typing import TYPE_CHECKING, Any, Literal, Optional, Type +import httpx from fastapi import HTTPException from litellm._logging import verbose_proxy_logger @@ -25,10 +26,12 @@ class OnyxGuardrail(CustomGuardrail): def __init__( - self, api_base: Optional[str] = None, api_key: Optional[str] = None, **kwargs + self, api_base: Optional[str] = None, api_key: Optional[str] = None, timeout: Optional[float] = 10.0, **kwargs ): + timeout = timeout or int(os.getenv("ONYX_TIMEOUT", 10.0)) self.async_handler = get_async_httpx_client( - llm_provider=httpxSpecialProvider.GuardrailCallback + llm_provider=httpxSpecialProvider.GuardrailCallback, + params={"timeout": httpx.Timeout(timeout=timeout, connect=5.0)}, ) self.api_base = api_base or os.getenv( "ONYX_API_BASE", diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/onyx.py b/litellm/types/proxy/guardrails/guardrail_hooks/onyx.py index aa5b9d7a3fc..42d7e94829f 100644 --- a/litellm/types/proxy/guardrails/guardrail_hooks/onyx.py +++ b/litellm/types/proxy/guardrails/guardrail_hooks/onyx.py @@ -16,6 +16,11 @@ class OnyxGuardrailConfigModel(GuardrailConfigModel): description="The API key for the Onyx Guard server. If not provided, the `ONYX_API_KEY` environment variable is checked.", ) + timeout: Optional[float] = Field( + default=None, + description="The timeout for the Onyx Guard server in seconds. If not provided, the `ONYX_TIMEOUT` environment variable is checked.", + ) + @staticmethod def ui_friendly_name() -> str: return "Onyx Guardrail" diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_onyx.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_onyx.py index 9ede649f392..fb7480d263c 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_onyx.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_onyx.py @@ -3,6 +3,7 @@ import uuid from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest from fastapi import HTTPException from httpx import Request, Response @@ -47,20 +48,129 @@ def test_onyx_guard_config(): del os.environ["ONYX_API_KEY"] +def test_onyx_guard_with_custom_timeout_from_kwargs(): + """Test Onyx guard instantiation with custom timeout passed via kwargs.""" + # Set environment variables for testing + os.environ["ONYX_API_BASE"] = "https://test.onyx.security" + os.environ["ONYX_API_KEY"] = "test-api-key" + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = MagicMock() + + # Simulate how guardrail is instantiated from config with timeout + guardrail = OnyxGuardrail( + guardrail_name="onyx-guard-custom-timeout", + event_hook="pre_call", + default_on=True, + timeout=45.0, + ) + + # Verify the client was initialized with custom timeout + mock_get_client.assert_called() + call_kwargs = mock_get_client.call_args.kwargs + timeout_param = call_kwargs["params"]["timeout"] + assert timeout_param.read == 45.0 + assert timeout_param.connect == 5.0 + + # Clean up + if "ONYX_API_BASE" in os.environ: + del os.environ["ONYX_API_BASE"] + if "ONYX_API_KEY" in os.environ: + del os.environ["ONYX_API_KEY"] + + +def test_onyx_guard_with_timeout_none_uses_env_var(): + """Test Onyx guard with timeout=None uses ONYX_TIMEOUT env var. + + When timeout=None is passed (as it would be from config model with default None), + the ONYX_TIMEOUT environment variable should be used. + """ + # Set environment variables for testing + os.environ["ONYX_API_BASE"] = "https://test.onyx.security" + os.environ["ONYX_API_KEY"] = "test-api-key" + os.environ["ONYX_TIMEOUT"] = "60" + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = MagicMock() + + # Pass timeout=None to simulate config model behavior + guardrail = OnyxGuardrail( + guardrail_name="onyx-guard-env-timeout", + event_hook="pre_call", + default_on=True, + timeout=None, # This triggers env var lookup + ) + + # Verify the client was initialized with timeout from env var + mock_get_client.assert_called() + call_kwargs = mock_get_client.call_args.kwargs + timeout_param = call_kwargs["params"]["timeout"] + assert timeout_param.read == 60.0 + assert timeout_param.connect == 5.0 + + # Clean up + if "ONYX_API_BASE" in os.environ: + del os.environ["ONYX_API_BASE"] + if "ONYX_API_KEY" in os.environ: + del os.environ["ONYX_API_KEY"] + if "ONYX_TIMEOUT" in os.environ: + del os.environ["ONYX_TIMEOUT"] + + +def test_onyx_guard_with_timeout_none_defaults_to_10(): + """Test Onyx guard with timeout=None and no env var defaults to 10 seconds.""" + # Set environment variables for testing + os.environ["ONYX_API_BASE"] = "https://test.onyx.security" + os.environ["ONYX_API_KEY"] = "test-api-key" + # Ensure ONYX_TIMEOUT is not set + if "ONYX_TIMEOUT" in os.environ: + del os.environ["ONYX_TIMEOUT"] + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = MagicMock() + + # Pass timeout=None with no env var - should default to 10.0 + guardrail = OnyxGuardrail( + guardrail_name="onyx-guard-default-timeout", + event_hook="pre_call", + default_on=True, + timeout=None, + ) + + # Verify the client was initialized with default timeout of 10.0 + mock_get_client.assert_called() + call_kwargs = mock_get_client.call_args.kwargs + timeout_param = call_kwargs["params"]["timeout"] + assert timeout_param.read == 10.0 + assert timeout_param.connect == 5.0 + + # Clean up + if "ONYX_API_BASE" in os.environ: + del os.environ["ONYX_API_BASE"] + if "ONYX_API_KEY" in os.environ: + del os.environ["ONYX_API_KEY"] + + class TestOnyxGuardrail: """Test suite for Onyx Security Guardrail integration.""" def setup_method(self): """Setup test environment.""" # Clean up any existing environment variables - for key in ["ONYX_API_BASE", "ONYX_API_KEY"]: + for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]: if key in os.environ: del os.environ[key] def teardown_method(self): """Clean up test environment.""" # Clean up any environment variables set during tests - for key in ["ONYX_API_BASE", "ONYX_API_KEY"]: + for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]: if key in os.environ: del os.environ[key] @@ -103,6 +213,95 @@ def test_initialization_fails_when_api_key_missing(self): ): OnyxGuardrail(guardrail_name="test-guard", event_hook="pre_call") + def test_initialization_with_default_timeout(self): + """Test that default timeout is 10.0 seconds.""" + os.environ["ONYX_API_KEY"] = "test-api-key" + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = MagicMock() + guardrail = OnyxGuardrail( + guardrail_name="test-guard", event_hook="pre_call", default_on=True + ) + + # Verify the client was initialized with correct timeout + mock_get_client.assert_called_once() + call_kwargs = mock_get_client.call_args.kwargs + timeout_param = call_kwargs["params"]["timeout"] + assert timeout_param.read == 10.0 + assert timeout_param.connect == 5.0 + + def test_initialization_with_custom_timeout_parameter(self): + """Test initialization with custom timeout parameter.""" + os.environ["ONYX_API_KEY"] = "test-api-key" + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = MagicMock() + guardrail = OnyxGuardrail( + guardrail_name="test-guard", + event_hook="pre_call", + default_on=True, + timeout=30.0, + ) + + # Verify the client was initialized with custom timeout + mock_get_client.assert_called_once() + call_kwargs = mock_get_client.call_args.kwargs + timeout_param = call_kwargs["params"]["timeout"] + assert timeout_param.read == 30.0 + assert timeout_param.connect == 5.0 + + def test_initialization_with_timeout_from_env_var(self): + """Test initialization with timeout from ONYX_TIMEOUT environment variable. + + Note: The env var is only used when timeout=None is explicitly passed, + since the default parameter value is 10.0 (not None). + """ + os.environ["ONYX_API_KEY"] = "test-api-key" + os.environ["ONYX_TIMEOUT"] = "25" + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = MagicMock() + # Must pass timeout=None explicitly to trigger env var lookup + guardrail = OnyxGuardrail( + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=None + ) + + # Verify the client was initialized with timeout from env var + mock_get_client.assert_called_once() + call_kwargs = mock_get_client.call_args.kwargs + timeout_param = call_kwargs["params"]["timeout"] + assert timeout_param.read == 25.0 + assert timeout_param.connect == 5.0 + + def test_initialization_timeout_parameter_overrides_env_var(self): + """Test that timeout parameter overrides ONYX_TIMEOUT environment variable.""" + os.environ["ONYX_API_KEY"] = "test-api-key" + os.environ["ONYX_TIMEOUT"] = "25" + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = MagicMock() + guardrail = OnyxGuardrail( + guardrail_name="test-guard", + event_hook="pre_call", + default_on=True, + timeout=15.0, + ) + + # Verify the client was initialized with parameter timeout (not env var) + mock_get_client.assert_called_once() + call_kwargs = mock_get_client.call_args.kwargs + timeout_param = call_kwargs["params"]["timeout"] + assert timeout_param.read == 15.0 + assert timeout_param.connect == 5.0 + @pytest.mark.asyncio async def test_apply_guardrail_request_no_violations(self): """Test apply_guardrail for request with no violations detected.""" @@ -388,6 +587,105 @@ async def test_apply_guardrail_api_error_handling(self): assert result == inputs + @pytest.mark.asyncio + async def test_apply_guardrail_timeout_error_handling(self): + """Test handling of timeout errors in apply_guardrail (graceful degradation).""" + # Set required API key + os.environ["ONYX_API_KEY"] = "test-api-key" + + guardrail = OnyxGuardrail( + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=1.0 + ) + + inputs = GenericGuardrailAPIInputs() + + request_data = { + "proxy_server_request": { + "messages": [{"role": "user", "content": "Test message"}], + "model": "gpt-3.5-turbo", + } + } + + # Test httpx timeout error + with patch.object( + guardrail.async_handler, "post", side_effect=httpx.TimeoutException("Request timed out") + ): + # Should return original inputs on timeout (graceful degradation) + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + logging_obj=None, + ) + + assert result == inputs + + @pytest.mark.asyncio + async def test_apply_guardrail_read_timeout_error_handling(self): + """Test handling of read timeout errors in apply_guardrail.""" + # Set required API key + os.environ["ONYX_API_KEY"] = "test-api-key" + + guardrail = OnyxGuardrail( + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0 + ) + + inputs = GenericGuardrailAPIInputs() + + request_data = { + "proxy_server_request": { + "messages": [{"role": "user", "content": "Test message"}], + "model": "gpt-3.5-turbo", + } + } + + # Test httpx ReadTimeout error + with patch.object( + guardrail.async_handler, "post", side_effect=httpx.ReadTimeout("Read timed out") + ): + # Should return original inputs on timeout (graceful degradation) + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + logging_obj=None, + ) + + assert result == inputs + + @pytest.mark.asyncio + async def test_apply_guardrail_connect_timeout_error_handling(self): + """Test handling of connect timeout errors in apply_guardrail.""" + # Set required API key + os.environ["ONYX_API_KEY"] = "test-api-key" + + guardrail = OnyxGuardrail( + guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0 + ) + + inputs = GenericGuardrailAPIInputs() + + request_data = { + "proxy_server_request": { + "messages": [{"role": "user", "content": "Test message"}], + "model": "gpt-3.5-turbo", + } + } + + # Test httpx ConnectTimeout error + with patch.object( + guardrail.async_handler, "post", side_effect=httpx.ConnectTimeout("Connect timed out") + ): + # Should return original inputs on timeout (graceful degradation) + result = await guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + logging_obj=None, + ) + + assert result == inputs + @pytest.mark.asyncio async def test_apply_guardrail_no_logging_obj(self): """Test apply_guardrail without logging object (uses UUID)."""