Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/utils/mcp_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
logger = logging.getLogger("app.endpoints.dependencies")


async def mcp_headers_dependency(_request: Request) -> dict[str, dict[str, str]]:
"""Get the mcp headers dependency to passed to mcp servers.
async def mcp_headers_dependency(request: Request) -> dict[str, dict[str, str]]:
"""Get the MCP headers dependency to passed to mcp servers.

mcp headers is a json dictionary or mcp url paths and their respective headers

Expand All @@ -23,7 +23,7 @@ async def mcp_headers_dependency(_request: Request) -> dict[str, dict[str, str]]
Returns:
The mcp headers dictionary, or empty dictionary if not found or on json decoding error
"""
return extract_mcp_headers(_request)
return extract_mcp_headers(request)


def extract_mcp_headers(request: Request) -> dict[str, dict[str, str]]:
Expand Down
184 changes: 184 additions & 0 deletions tests/unit/utils/test_mcp_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Unit tests for MCP headers utility functions."""

from unittest.mock import Mock
import pytest

from fastapi import Request

from utils import mcp_headers


def test_extract_mcp_headers_empty_headers():
"""Test the extract_mcp_headers function for request without any headers."""
request = Mock(spec=Request)
# no headers
request.headers = {}

result = mcp_headers.extract_mcp_headers(request)
assert result == {}


def test_extract_mcp_headers_mcp_headers_empty():
"""Test the extract_mcp_headers function for request with empty MCP-HEADERS header."""
request = Mock(spec=Request)
# empty MCP-HEADERS
request.headers = {"MCP-HEADERS": ""}

# empty dict should be returned
result = mcp_headers.extract_mcp_headers(request)
assert result == {}


def test_extract_mcp_headers_valid_mcp_header():
"""Test the extract_mcp_headers function for request with valid MCP-HEADERS header."""
request = Mock(spec=Request)
# valid MCP-HEADERS
request.headers = {"MCP-HEADERS": '{"http://www.redhat.com": {"auth": "token123"}}'}

result = mcp_headers.extract_mcp_headers(request)

expected = {"http://www.redhat.com": {"auth": "token123"}}
assert result == expected


def test_extract_mcp_headers_valid_mcp_headers():
"""Test the extract_mcp_headers function for request with valid MCP-HEADERS headers."""
request = Mock(spec=Request)
# valid MCP-HEADERS
header1 = '"http://www.redhat.com": {"auth": "token123"}'
header2 = '"http://www.example.com": {"auth": "tokenXYZ"}'

request.headers = {"MCP-HEADERS": f"{{{header1}, {header2}}}"}

result = mcp_headers.extract_mcp_headers(request)

expected = {
"http://www.redhat.com": {"auth": "token123"},
"http://www.example.com": {"auth": "tokenXYZ"},
}
assert result == expected


def test_extract_mcp_headers_invalid_json_mcp_header():
"""Test the extract_mcp_headers function for request with invalid MCP-HEADERS header."""
request = Mock(spec=Request)
# invalid MCP-HEADERS - not a JSON
request.headers = {"MCP-HEADERS": "this-is-invalid"}

# empty dict should be returned
result = mcp_headers.extract_mcp_headers(request)
assert result == {}


def test_extract_mcp_headers_invalid_mcp_header_type():
"""Test the extract_mcp_headers function for request with invalid MCP-HEADERS header type."""
request = Mock(spec=Request)
# invalid MCP-HEADERS - not a dict
request.headers = {"MCP-HEADERS": "[]"}

# empty dict should be returned
result = mcp_headers.extract_mcp_headers(request)
assert result == {}


def test_extract_mcp_headers_invalid_mcp_header_null_value():
"""Test the extract_mcp_headers function for request with invalid MCP-HEADERS header type."""
request = Mock(spec=Request)
# invalid MCP-HEADERS - not a dict
request.headers = {"MCP-HEADERS": "null"}

# empty dict should be returned
result = mcp_headers.extract_mcp_headers(request)
assert result == {}


@pytest.mark.asyncio
async def test_mcp_headers_dependency_empty_headers():
"""Test the mcp_headers_dependency function for request with empty MCP-HEADERS header."""
request = Mock(spec=Request)
# empty MCP-HEADERS
request.headers = {"MCP-HEADERS": ""}

# empty dict should be returned
result = await mcp_headers.mcp_headers_dependency(request)
assert result == {}


@pytest.mark.asyncio
async def test_mcp_headers_dependency_mcp_headers_empty():
"""Test the mcp_headers_dependency function for request with empty MCP-HEADERS header."""
request = Mock(spec=Request)
# empty MCP-HEADERS
request.headers = {"MCP-HEADERS": ""}

# empty dict should be returned
result = await mcp_headers.mcp_headers_dependency(request)
assert result == {}


@pytest.mark.asyncio
async def test_mcp_headers_dependency_valid_mcp_header():
"""Test the mcp_headers_dependency function for request with valid MCP-HEADERS header."""
request = Mock(spec=Request)
# valid MCP-HEADERS
request.headers = {"MCP-HEADERS": '{"http://www.redhat.com": {"auth": "token123"}}'}

result = await mcp_headers.mcp_headers_dependency(request)

expected = {"http://www.redhat.com": {"auth": "token123"}}
assert result == expected


@pytest.mark.asyncio
async def test_mcp_headers_dependency_valid_mcp_headers():
"""Test the mcp_headers_dependency function for request with valid MCP-HEADERS headers."""
request = Mock(spec=Request)
# valid MCP-HEADERS
header1 = '"http://www.redhat.com": {"auth": "token123"}'
header2 = '"http://www.example.com": {"auth": "tokenXYZ"}'

request.headers = {"MCP-HEADERS": f"{{{header1}, {header2}}}"}

result = await mcp_headers.mcp_headers_dependency(request)

expected = {
"http://www.redhat.com": {"auth": "token123"},
"http://www.example.com": {"auth": "tokenXYZ"},
}
assert result == expected


@pytest.mark.asyncio
async def test_mcp_headers_dependency_invalid_json_mcp_header():
"""Test the mcp_headers_dependency function for request with invalid MCP-HEADERS header."""
request = Mock(spec=Request)
# invalid MCP-HEADERS - not a JSON
request.headers = {"MCP-HEADERS": "this-is-invalid"}

# empty dict should be returned
result = await mcp_headers.mcp_headers_dependency(request)
assert result == {}


@pytest.mark.asyncio
async def test_mcp_headers_dependency_invalid_mcp_header_type():
"""Test the mcp_headers_dependency function for request with invalid MCP-HEADERS header type."""
request = Mock(spec=Request)
# invalid MCP-HEADERS - not a dict
request.headers = {"MCP-HEADERS": "[]"}

# empty dict should be returned
result = await mcp_headers.mcp_headers_dependency(request)
assert result == {}


@pytest.mark.asyncio
async def test_mcp_headers_dependency_invalid_mcp_header_null_value():
"""Test the mcp_headers_dependency function for request with invalid MCP-HEADERS header type."""
request = Mock(spec=Request)
# invalid MCP-HEADERS - not a dict
request.headers = {"MCP-HEADERS": "null"}

# empty dict should be returned
result = await mcp_headers.mcp_headers_dependency(request)
assert result == {}