diff --git a/src/fastmcp/server/middleware/error_handling.py b/src/fastmcp/server/middleware/error_handling.py index 7cb730d905..fc0e2a1c80 100644 --- a/src/fastmcp/server/middleware/error_handling.py +++ b/src/fastmcp/server/middleware/error_handling.py @@ -87,7 +87,7 @@ def _transform_error(self, error: Exception) -> Exception: return error # Map common exceptions to appropriate MCP error codes - error_type = type(error) + error_type = type(error.__cause__) if error.__cause__ else type(error) if error_type in (ValueError, TypeError): return McpError( diff --git a/tests/server/middleware/test_error_handling.py b/tests/server/middleware/test_error_handling.py index 1ed534709f..7da8c31192 100644 --- a/tests/server/middleware/test_error_handling.py +++ b/tests/server/middleware/test_error_handling.py @@ -6,7 +6,7 @@ import pytest from mcp import McpError -from fastmcp.exceptions import NotFoundError +from fastmcp.exceptions import NotFoundError, ToolError from fastmcp.server.middleware.error_handling import ( ErrorHandlingMiddleware, RetryMiddleware, @@ -215,6 +215,23 @@ async def test_on_message_error_transform(self, mock_context, caplog): assert "Invalid params: test error" in exc_info.value.error.message assert "Error in test_method: ValueError: test error" in caplog.text + async def test_on_message_error_transform_tool_error(self, mock_context, caplog): + """Test error handling with transformation and cause type.""" + middleware = ErrorHandlingMiddleware() + tool_error = ToolError("test error") + tool_error.__cause__ = ValueError() + mock_call_next = AsyncMock(side_effect=tool_error) + + with caplog_for_fastmcp(caplog): + with caplog.at_level(logging.ERROR): + with pytest.raises(McpError) as exc_info: + await middleware.on_message(mock_context, mock_call_next) + + assert isinstance(exc_info.value, McpError) + assert exc_info.value.error.code == -32602 + assert "Invalid params: test error" in exc_info.value.error.message + assert "Error in test_method: ToolError: test error" in caplog.text + def test_get_error_stats(self, mock_context): """Test getting error statistics.""" middleware = ErrorHandlingMiddleware()