diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 9d634b73999..8fb5f745745 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -16,11 +16,12 @@ package org.springframework.ai.mcp; -import java.util.Map; - import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; +import java.util.Map; +import reactor.core.publisher.Mono; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; @@ -112,19 +113,16 @@ public String call(String functionInput) { Map arguments = ModelOptionsUtils.jsonToMap(functionInput); // Note that we use the original tool name here, not the adapted one from // getToolDefinition - try { - return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).map(response -> { - if (response.isError() != null && response.isError()) { - throw new ToolExecutionException(this.getToolDefinition(), - new IllegalStateException("Error calling tool: " + response.content())); - } - return ModelOptionsUtils.toJsonString(response.content()); - }).block(); - } - catch (Exception ex) { - throw new ToolExecutionException(this.getToolDefinition(), ex.getCause()); - } - + return this.asyncMcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)).onErrorMap(exception -> { + // If the tool throws an error during execution + throw new ToolExecutionException(this.getToolDefinition(), exception); + }).map(response -> { + if (response.isError() != null && response.isError()) { + throw new ToolExecutionException(this.getToolDefinition(), + new IllegalStateException("Error calling tool: " + response.content())); + } + return ModelOptionsUtils.toJsonString(response.content()); + }).block(); } @Override diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index 442f21eb89a..fc61d801df1 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -16,13 +16,11 @@ package org.springframework.ai.mcp; -import java.lang.reflect.InvocationTargetException; -import java.util.Map; - import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,7 +30,6 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; -import org.springframework.core.log.LogAccessor; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -118,22 +115,24 @@ public ToolDefinition getToolDefinition() { @Override public String call(String functionInput) { Map arguments = ModelOptionsUtils.jsonToMap(functionInput); - // Note that we use the original tool name here, not the adapted one from - // getToolDefinition + + CallToolResult response; try { - CallToolResult response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)); - if (response.isError() != null && response.isError()) { - logger.error("Error calling tool: {}", response.content()); - throw new ToolExecutionException(this.getToolDefinition(), - new IllegalStateException("Error calling tool: " + response.content())); - } - return ModelOptionsUtils.toJsonString(response.content()); + // Note that we use the original tool name here, not the adapted one from + // getToolDefinition + response = this.mcpClient.callTool(new CallToolRequest(this.tool.name(), arguments)); } catch (Exception ex) { logger.error("Exception while tool calling: ", ex); - throw new ToolExecutionException(this.getToolDefinition(), ex.getCause()); + throw new ToolExecutionException(this.getToolDefinition(), ex); } + if (response.isError() != null && response.isError()) { + logger.error("Error calling tool: {}", response.content()); + throw new ToolExecutionException(this.getToolDefinition(), + new IllegalStateException("Error calling tool: " + response.content())); + } + return ModelOptionsUtils.toJsonString(response.content()); } @Override diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java new file mode 100644 index 00000000000..e4ceb618efd --- /dev/null +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java @@ -0,0 +1,54 @@ +package org.springframework.ai.mcp; + +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Mono; + +import org.springframework.ai.tool.execution.ToolExecutionException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +class AsyncMcpToolCallbackTest { + + @Mock + private McpAsyncClient mcpClient; + + @Mock + private McpSchema.Tool tool; + + @Test + void callShouldThrowOnError() { + when(this.tool.name()).thenReturn("testTool"); + var clientInfo = new McpSchema.Implementation("testClient", "1.0.0"); + when(this.mcpClient.getClientInfo()).thenReturn(clientInfo); + var callToolResult = McpSchema.CallToolResult.builder().addTextContent("Some error data").isError(true).build(); + when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))).thenReturn(Mono.just(callToolResult)); + + var callback = new AsyncMcpToolCallback(this.mcpClient, this.tool); + assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessage("Error calling tool: [TextContent[audience=null, priority=null, text=Some error data]]"); + } + + @Test + void callShouldWrapReactiveErrors() { + when(this.tool.name()).thenReturn("testTool"); + var clientInfo = new McpSchema.Implementation("testClient", "1.0.0"); + when(this.mcpClient.getClientInfo()).thenReturn(clientInfo); + when(this.mcpClient.callTool(any(McpSchema.CallToolRequest.class))) + .thenReturn(Mono.error(new Exception("Testing tool error"))); + + var callback = new AsyncMcpToolCallback(this.mcpClient, this.tool); + assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class) + .rootCause() + .hasMessage("Testing tool error"); + } + +} diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java index 4ed2483e64f..72b04394a49 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java @@ -16,6 +16,8 @@ package org.springframework.ai.mcp; +import io.modelcontextprotocol.spec.McpSchema; +import java.util.List; import java.util.Map; import io.modelcontextprotocol.client.McpSyncClient; @@ -29,8 +31,11 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.content.Content; +import org.springframework.ai.tool.execution.ToolExecutionException; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -94,4 +99,36 @@ void callShouldIgnoreToolContext() { assertThat(response).isNotNull(); } + @Test + void callShouldThrowOnError() { + when(this.tool.name()).thenReturn("testTool"); + var clientInfo = new Implementation("testClient", "1.0.0"); + when(this.mcpClient.getClientInfo()).thenReturn(clientInfo); + CallToolResult callResult = mock(CallToolResult.class); + when(callResult.isError()).thenReturn(true); + when(callResult.content()).thenReturn(List.of(new McpSchema.TextContent("Some error data"))); + when(this.mcpClient.callTool(any(CallToolRequest.class))).thenReturn(callResult); + + SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool); + + assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class) + .cause() + .isInstanceOf(IllegalStateException.class) + .hasMessage("Error calling tool: [TextContent[audience=null, priority=null, text=Some error data]]"); + } + + @Test + void callShouldWrapExceptions() { + when(this.tool.name()).thenReturn("testTool"); + var clientInfo = new Implementation("testClient", "1.0.0"); + when(this.mcpClient.getClientInfo()).thenReturn(clientInfo); + when(this.mcpClient.callTool(any(CallToolRequest.class))).thenThrow(new RuntimeException("Testing tool error")); + + SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool); + + assertThatThrownBy(() -> callback.call("{\"param\":\"value\"}")).isInstanceOf(ToolExecutionException.class) + .rootCause() + .hasMessage("Testing tool error"); + } + }