Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/my-website/docs/proxy/guardrails/onyx_security.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ guardrails:
mode: ["pre_call", "post_call", "during_call"] # Run at multiple stages
api_key: os.environ/ONYX_API_KEY
api_base: os.environ/ONYX_API_BASE
timeout: 10.0 # Optional, defaults to 10 seconds
```

### Required Parameters
Expand All @@ -137,6 +138,7 @@ guardrails:
### Optional Parameters

- **`api_base`**: Onyx API base URL (defaults to `https://ai-guard.onyx.security`)
- **`timeout`**: Request timeout in seconds (defaults to `10.0`)

## Environment Variables

Expand All @@ -145,4 +147,5 @@ You can set these environment variables instead of hardcoding values in your con
```shell
export ONYX_API_KEY="your-api-key-here"
export ONYX_API_BASE="https://ai-guard.onyx.security" # Optional
export ONYX_TIMEOUT=10 # Optional, timeout in seconds
```
7 changes: 5 additions & 2 deletions litellm/proxy/guardrails/guardrail_hooks/onyx/onyx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import uuid
from typing import TYPE_CHECKING, Any, Literal, Optional, Type

import httpx
from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
Expand All @@ -25,10 +26,12 @@

class OnyxGuardrail(CustomGuardrail):
def __init__(
self, api_base: Optional[str] = None, api_key: Optional[str] = None, **kwargs
self, api_base: Optional[str] = None, api_key: Optional[str] = None, timeout: Optional[float] = 10.0, **kwargs
):
timeout = timeout or int(os.getenv("ONYX_TIMEOUT", 10.0))
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
llm_provider=httpxSpecialProvider.GuardrailCallback,
params={"timeout": httpx.Timeout(timeout=timeout, connect=5.0)},
)
self.api_base = api_base or os.getenv(
"ONYX_API_BASE",
Expand Down
5 changes: 5 additions & 0 deletions litellm/types/proxy/guardrails/guardrail_hooks/onyx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class OnyxGuardrailConfigModel(GuardrailConfigModel):
description="The API key for the Onyx Guard server. If not provided, the `ONYX_API_KEY` environment variable is checked.",
)

timeout: Optional[float] = Field(
default=None,
description="The timeout for the Onyx Guard server in seconds. If not provided, the `ONYX_TIMEOUT` environment variable is checked.",
)

@staticmethod
def ui_friendly_name() -> str:
return "Onyx Guardrail"
302 changes: 300 additions & 2 deletions tests/test_litellm/proxy/guardrails/guardrail_hooks/test_onyx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
from fastapi import HTTPException
from httpx import Request, Response
Expand Down Expand Up @@ -47,20 +48,129 @@ def test_onyx_guard_config():
del os.environ["ONYX_API_KEY"]


def test_onyx_guard_with_custom_timeout_from_kwargs():
"""Test Onyx guard instantiation with custom timeout passed via kwargs."""
# Set environment variables for testing
os.environ["ONYX_API_BASE"] = "https://test.onyx.security"
os.environ["ONYX_API_KEY"] = "test-api-key"

with patch(
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
) as mock_get_client:
mock_get_client.return_value = MagicMock()

# Simulate how guardrail is instantiated from config with timeout
guardrail = OnyxGuardrail(
guardrail_name="onyx-guard-custom-timeout",
event_hook="pre_call",
default_on=True,
timeout=45.0,
)

# Verify the client was initialized with custom timeout
mock_get_client.assert_called()
call_kwargs = mock_get_client.call_args.kwargs
timeout_param = call_kwargs["params"]["timeout"]
assert timeout_param.read == 45.0
assert timeout_param.connect == 5.0

# Clean up
if "ONYX_API_BASE" in os.environ:
del os.environ["ONYX_API_BASE"]
if "ONYX_API_KEY" in os.environ:
del os.environ["ONYX_API_KEY"]


def test_onyx_guard_with_timeout_none_uses_env_var():
"""Test Onyx guard with timeout=None uses ONYX_TIMEOUT env var.

When timeout=None is passed (as it would be from config model with default None),
the ONYX_TIMEOUT environment variable should be used.
"""
# Set environment variables for testing
os.environ["ONYX_API_BASE"] = "https://test.onyx.security"
os.environ["ONYX_API_KEY"] = "test-api-key"
os.environ["ONYX_TIMEOUT"] = "60"

with patch(
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
) as mock_get_client:
mock_get_client.return_value = MagicMock()

# Pass timeout=None to simulate config model behavior
guardrail = OnyxGuardrail(
guardrail_name="onyx-guard-env-timeout",
event_hook="pre_call",
default_on=True,
timeout=None, # This triggers env var lookup
)

# Verify the client was initialized with timeout from env var
mock_get_client.assert_called()
call_kwargs = mock_get_client.call_args.kwargs
timeout_param = call_kwargs["params"]["timeout"]
assert timeout_param.read == 60.0
assert timeout_param.connect == 5.0

# Clean up
if "ONYX_API_BASE" in os.environ:
del os.environ["ONYX_API_BASE"]
if "ONYX_API_KEY" in os.environ:
del os.environ["ONYX_API_KEY"]
if "ONYX_TIMEOUT" in os.environ:
del os.environ["ONYX_TIMEOUT"]


def test_onyx_guard_with_timeout_none_defaults_to_10():
"""Test Onyx guard with timeout=None and no env var defaults to 10 seconds."""
# Set environment variables for testing
os.environ["ONYX_API_BASE"] = "https://test.onyx.security"
os.environ["ONYX_API_KEY"] = "test-api-key"
# Ensure ONYX_TIMEOUT is not set
if "ONYX_TIMEOUT" in os.environ:
del os.environ["ONYX_TIMEOUT"]

with patch(
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
) as mock_get_client:
mock_get_client.return_value = MagicMock()

# Pass timeout=None with no env var - should default to 10.0
guardrail = OnyxGuardrail(
guardrail_name="onyx-guard-default-timeout",
event_hook="pre_call",
default_on=True,
timeout=None,
)

# Verify the client was initialized with default timeout of 10.0
mock_get_client.assert_called()
call_kwargs = mock_get_client.call_args.kwargs
timeout_param = call_kwargs["params"]["timeout"]
assert timeout_param.read == 10.0
assert timeout_param.connect == 5.0

# Clean up
if "ONYX_API_BASE" in os.environ:
del os.environ["ONYX_API_BASE"]
if "ONYX_API_KEY" in os.environ:
del os.environ["ONYX_API_KEY"]


class TestOnyxGuardrail:
"""Test suite for Onyx Security Guardrail integration."""

def setup_method(self):
"""Setup test environment."""
# Clean up any existing environment variables
for key in ["ONYX_API_BASE", "ONYX_API_KEY"]:
for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]:
if key in os.environ:
del os.environ[key]

def teardown_method(self):
"""Clean up test environment."""
# Clean up any environment variables set during tests
for key in ["ONYX_API_BASE", "ONYX_API_KEY"]:
for key in ["ONYX_API_BASE", "ONYX_API_KEY", "ONYX_TIMEOUT"]:
if key in os.environ:
del os.environ[key]

Expand Down Expand Up @@ -103,6 +213,95 @@ def test_initialization_fails_when_api_key_missing(self):
):
OnyxGuardrail(guardrail_name="test-guard", event_hook="pre_call")

def test_initialization_with_default_timeout(self):
"""Test that default timeout is 10.0 seconds."""
os.environ["ONYX_API_KEY"] = "test-api-key"

with patch(
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
) as mock_get_client:
mock_get_client.return_value = MagicMock()
guardrail = OnyxGuardrail(
guardrail_name="test-guard", event_hook="pre_call", default_on=True
)

# Verify the client was initialized with correct timeout
mock_get_client.assert_called_once()
call_kwargs = mock_get_client.call_args.kwargs
timeout_param = call_kwargs["params"]["timeout"]
assert timeout_param.read == 10.0
assert timeout_param.connect == 5.0

def test_initialization_with_custom_timeout_parameter(self):
"""Test initialization with custom timeout parameter."""
os.environ["ONYX_API_KEY"] = "test-api-key"

with patch(
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
) as mock_get_client:
mock_get_client.return_value = MagicMock()
guardrail = OnyxGuardrail(
guardrail_name="test-guard",
event_hook="pre_call",
default_on=True,
timeout=30.0,
)

# Verify the client was initialized with custom timeout
mock_get_client.assert_called_once()
call_kwargs = mock_get_client.call_args.kwargs
timeout_param = call_kwargs["params"]["timeout"]
assert timeout_param.read == 30.0
assert timeout_param.connect == 5.0

def test_initialization_with_timeout_from_env_var(self):
"""Test initialization with timeout from ONYX_TIMEOUT environment variable.

Note: The env var is only used when timeout=None is explicitly passed,
since the default parameter value is 10.0 (not None).
"""
os.environ["ONYX_API_KEY"] = "test-api-key"
os.environ["ONYX_TIMEOUT"] = "25"

with patch(
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
) as mock_get_client:
mock_get_client.return_value = MagicMock()
# Must pass timeout=None explicitly to trigger env var lookup
guardrail = OnyxGuardrail(
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=None
)

# Verify the client was initialized with timeout from env var
mock_get_client.assert_called_once()
call_kwargs = mock_get_client.call_args.kwargs
timeout_param = call_kwargs["params"]["timeout"]
assert timeout_param.read == 25.0
assert timeout_param.connect == 5.0

def test_initialization_timeout_parameter_overrides_env_var(self):
"""Test that timeout parameter overrides ONYX_TIMEOUT environment variable."""
os.environ["ONYX_API_KEY"] = "test-api-key"
os.environ["ONYX_TIMEOUT"] = "25"

with patch(
"litellm.proxy.guardrails.guardrail_hooks.onyx.onyx.get_async_httpx_client"
) as mock_get_client:
mock_get_client.return_value = MagicMock()
guardrail = OnyxGuardrail(
guardrail_name="test-guard",
event_hook="pre_call",
default_on=True,
timeout=15.0,
)

# Verify the client was initialized with parameter timeout (not env var)
mock_get_client.assert_called_once()
call_kwargs = mock_get_client.call_args.kwargs
timeout_param = call_kwargs["params"]["timeout"]
assert timeout_param.read == 15.0
assert timeout_param.connect == 5.0

@pytest.mark.asyncio
async def test_apply_guardrail_request_no_violations(self):
"""Test apply_guardrail for request with no violations detected."""
Expand Down Expand Up @@ -388,6 +587,105 @@ async def test_apply_guardrail_api_error_handling(self):

assert result == inputs

@pytest.mark.asyncio
async def test_apply_guardrail_timeout_error_handling(self):
"""Test handling of timeout errors in apply_guardrail (graceful degradation)."""
# Set required API key
os.environ["ONYX_API_KEY"] = "test-api-key"

guardrail = OnyxGuardrail(
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=1.0
)

inputs = GenericGuardrailAPIInputs()

request_data = {
"proxy_server_request": {
"messages": [{"role": "user", "content": "Test message"}],
"model": "gpt-3.5-turbo",
}
}

# Test httpx timeout error
with patch.object(
guardrail.async_handler, "post", side_effect=httpx.TimeoutException("Request timed out")
):
# Should return original inputs on timeout (graceful degradation)
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
logging_obj=None,
)

assert result == inputs

@pytest.mark.asyncio
async def test_apply_guardrail_read_timeout_error_handling(self):
"""Test handling of read timeout errors in apply_guardrail."""
# Set required API key
os.environ["ONYX_API_KEY"] = "test-api-key"

guardrail = OnyxGuardrail(
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0
)

inputs = GenericGuardrailAPIInputs()

request_data = {
"proxy_server_request": {
"messages": [{"role": "user", "content": "Test message"}],
"model": "gpt-3.5-turbo",
}
}

# Test httpx ReadTimeout error
with patch.object(
guardrail.async_handler, "post", side_effect=httpx.ReadTimeout("Read timed out")
):
# Should return original inputs on timeout (graceful degradation)
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
logging_obj=None,
)

assert result == inputs

@pytest.mark.asyncio
async def test_apply_guardrail_connect_timeout_error_handling(self):
"""Test handling of connect timeout errors in apply_guardrail."""
# Set required API key
os.environ["ONYX_API_KEY"] = "test-api-key"

guardrail = OnyxGuardrail(
guardrail_name="test-guard", event_hook="pre_call", default_on=True, timeout=5.0
)

inputs = GenericGuardrailAPIInputs()

request_data = {
"proxy_server_request": {
"messages": [{"role": "user", "content": "Test message"}],
"model": "gpt-3.5-turbo",
}
}

# Test httpx ConnectTimeout error
with patch.object(
guardrail.async_handler, "post", side_effect=httpx.ConnectTimeout("Connect timed out")
):
# Should return original inputs on timeout (graceful degradation)
result = await guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
logging_obj=None,
)

assert result == inputs

@pytest.mark.asyncio
async def test_apply_guardrail_no_logging_obj(self):
"""Test apply_guardrail without logging object (uses UUID)."""
Expand Down
Loading