diff --git a/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts b/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts new file mode 100644 index 00000000000..9e33b0552c9 --- /dev/null +++ b/src/core/prompts/tools/native-tools/__tests__/mcp_server.spec.ts @@ -0,0 +1,195 @@ +import type OpenAI from "openai" +import { getMcpServerTools } from "../mcp_server" +import type { McpHub } from "../../../../../services/mcp/McpHub" +import type { McpServer, McpTool } from "../../../../../shared/mcp" + +// Helper type to access function tools +type FunctionTool = OpenAI.Chat.ChatCompletionTool & { type: "function" } + +// Helper to get the function property from a tool +const getFunction = (tool: OpenAI.Chat.ChatCompletionTool) => (tool as FunctionTool).function + +describe("getMcpServerTools", () => { + const createMockTool = (name: string, description = "Test tool"): McpTool => ({ + name, + description, + inputSchema: { + type: "object", + properties: {}, + }, + }) + + const createMockServer = (name: string, tools: McpTool[], source: "global" | "project" = "global"): McpServer => ({ + name, + config: JSON.stringify({ type: "stdio", command: "test" }), + status: "connected", + source, + tools, + }) + + const createMockMcpHub = (servers: McpServer[]): Partial => ({ + getServers: vi.fn().mockReturnValue(servers), + }) + + it("should return empty array when mcpHub is undefined", () => { + const result = getMcpServerTools(undefined) + expect(result).toEqual([]) + }) + + it("should return empty array when no servers are available", () => { + const mockHub = createMockMcpHub([]) + const result = getMcpServerTools(mockHub as McpHub) + expect(result).toEqual([]) + }) + + it("should generate tool definitions for server tools", () => { + const server = createMockServer("testServer", [createMockTool("testTool")]) + const mockHub = createMockMcpHub([server]) + + const result = getMcpServerTools(mockHub as McpHub) + + expect(result).toHaveLength(1) + expect(result[0].type).toBe("function") + expect(getFunction(result[0]).name).toBe("mcp--testServer--testTool") + expect(getFunction(result[0]).description).toBe("Test tool") + }) + + it("should filter out tools with enabledForPrompt set to false", () => { + const enabledTool = createMockTool("enabledTool") + const disabledTool = { ...createMockTool("disabledTool"), enabledForPrompt: false } + const server = createMockServer("testServer", [enabledTool, disabledTool]) + const mockHub = createMockMcpHub([server]) + + const result = getMcpServerTools(mockHub as McpHub) + + expect(result).toHaveLength(1) + expect(getFunction(result[0]).name).toBe("mcp--testServer--enabledTool") + }) + + it("should deduplicate tools when same server exists in both global and project configs", () => { + const globalServer = createMockServer( + "context7", + [createMockTool("resolve-library-id", "Global description")], + "global", + ) + const projectServer = createMockServer( + "context7", + [createMockTool("resolve-library-id", "Project description")], + "project", + ) + + // McpHub.getServers() deduplicates with project servers taking priority + // This test simulates the deduplicated result (only project server returned) + const mockHub = createMockMcpHub([projectServer]) + + const result = getMcpServerTools(mockHub as McpHub) + + // Should only have one tool (from project server) + expect(result).toHaveLength(1) + expect(getFunction(result[0]).name).toBe("mcp--context7--resolve-library-id") + // Project server takes priority + expect(getFunction(result[0]).description).toBe("Project description") + }) + + it("should allow tools with different names from the same server", () => { + const server = createMockServer("testServer", [ + createMockTool("tool1"), + createMockTool("tool2"), + createMockTool("tool3"), + ]) + const mockHub = createMockMcpHub([server]) + + const result = getMcpServerTools(mockHub as McpHub) + + expect(result).toHaveLength(3) + const toolNames = result.map((t) => getFunction(t).name) + expect(toolNames).toContain("mcp--testServer--tool1") + expect(toolNames).toContain("mcp--testServer--tool2") + expect(toolNames).toContain("mcp--testServer--tool3") + }) + + it("should allow tools with same name from different servers", () => { + const server1 = createMockServer("server1", [createMockTool("commonTool")]) + const server2 = createMockServer("server2", [createMockTool("commonTool")]) + const mockHub = createMockMcpHub([server1, server2]) + + const result = getMcpServerTools(mockHub as McpHub) + + expect(result).toHaveLength(2) + const toolNames = result.map((t) => getFunction(t).name) + expect(toolNames).toContain("mcp--server1--commonTool") + expect(toolNames).toContain("mcp--server2--commonTool") + }) + + it("should skip servers without tools", () => { + const serverWithTools = createMockServer("withTools", [createMockTool("tool1")]) + const serverWithoutTools = createMockServer("withoutTools", []) + const serverWithUndefinedTools: McpServer = { + ...createMockServer("undefinedTools", []), + tools: undefined, + } + const mockHub = createMockMcpHub([serverWithTools, serverWithoutTools, serverWithUndefinedTools]) + + const result = getMcpServerTools(mockHub as McpHub) + + expect(result).toHaveLength(1) + expect(getFunction(result[0]).name).toBe("mcp--withTools--tool1") + }) + + it("should include required fields from tool schema", () => { + const toolWithRequired: McpTool = { + name: "toolWithRequired", + description: "Tool with required fields", + inputSchema: { + type: "object", + properties: { + requiredField: { type: "string" }, + optionalField: { type: "number" }, + }, + required: ["requiredField"], + }, + } + const server = createMockServer("testServer", [toolWithRequired]) + const mockHub = createMockMcpHub([server]) + + const result = getMcpServerTools(mockHub as McpHub) + + expect(result).toHaveLength(1) + expect(getFunction(result[0]).parameters).toEqual({ + type: "object", + properties: { + requiredField: { type: "string" }, + optionalField: { type: "number" }, + }, + additionalProperties: false, + required: ["requiredField"], + }) + }) + + it("should not include required field when schema has no required fields", () => { + const toolWithoutRequired: McpTool = { + name: "toolWithoutRequired", + description: "Tool without required fields", + inputSchema: { + type: "object", + properties: { + optionalField: { type: "string" }, + }, + }, + } + const server = createMockServer("testServer", [toolWithoutRequired]) + const mockHub = createMockMcpHub([server]) + + const result = getMcpServerTools(mockHub as McpHub) + + expect(result).toHaveLength(1) + expect(getFunction(result[0]).parameters).toEqual({ + type: "object", + properties: { + optionalField: { type: "string" }, + }, + additionalProperties: false, + }) + expect(getFunction(result[0]).parameters).not.toHaveProperty("required") + }) +}) diff --git a/src/core/prompts/tools/native-tools/mcp_server.ts b/src/core/prompts/tools/native-tools/mcp_server.ts index f40da7cf500..3b47f84adf4 100644 --- a/src/core/prompts/tools/native-tools/mcp_server.ts +++ b/src/core/prompts/tools/native-tools/mcp_server.ts @@ -4,6 +4,8 @@ import { buildMcpToolName } from "../../../../utils/mcp-name" /** * Dynamically generates native tool definitions for all enabled tools across connected MCP servers. + * Tools are deduplicated by name to prevent API errors. When the same server exists in both + * global and project configs, project servers take priority (handled by McpHub.getServers()). * * @param mcpHub The McpHub instance containing connected servers. * @returns An array of OpenAI.Chat.ChatCompletionTool definitions. @@ -15,6 +17,8 @@ export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTo const servers = mcpHub.getServers() const tools: OpenAI.Chat.ChatCompletionTool[] = [] + // Track seen tool names to prevent duplicates (e.g., when same server exists in both global and project configs) + const seenToolNames = new Set() for (const server of servers) { if (!server.tools) { @@ -26,6 +30,16 @@ export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTo continue } + // Build sanitized tool name for API compliance + // The name is sanitized to conform to API requirements (e.g., Gemini's function name restrictions) + const toolName = buildMcpToolName(server.name, tool.name) + + // Skip duplicate tool names - first occurrence wins (project servers come before global servers) + if (seenToolNames.has(toolName)) { + continue + } + seenToolNames.add(toolName) + const originalSchema = tool.inputSchema as Record | undefined const toolInputProps = originalSchema?.properties ?? {} const toolInputRequired = (originalSchema?.required ?? []) as string[] @@ -44,10 +58,6 @@ export function getMcpServerTools(mcpHub?: McpHub): OpenAI.Chat.ChatCompletionTo parameters.required = toolInputRequired } - // Build sanitized tool name for API compliance - // The name is sanitized to conform to API requirements (e.g., Gemini's function name restrictions) - const toolName = buildMcpToolName(server.name, tool.name) - const toolDefinition: OpenAI.Chat.ChatCompletionTool = { type: "function", function: { diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 3d54cb670e2..1c35c8b89f2 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -435,8 +435,23 @@ export class McpHub { } getServers(): McpServer[] { - // Only return enabled servers - return this.connections.filter((conn) => !conn.server.disabled).map((conn) => conn.server) + // Only return enabled servers, deduplicating by name with project servers taking priority + const enabledConnections = this.connections.filter((conn) => !conn.server.disabled) + + // Deduplicate by server name: project servers take priority over global servers + const serversByName = new Map() + for (const conn of enabledConnections) { + const existing = serversByName.get(conn.server.name) + if (!existing) { + serversByName.set(conn.server.name, conn.server) + } else if (conn.server.source === "project" && existing.source !== "project") { + // Project server overrides global server with the same name + serversByName.set(conn.server.name, conn.server) + } + // If existing is project and current is global, keep existing (project wins) + } + + return Array.from(serversByName.values()) } getAllServers(): McpServer[] { diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 1db924ed6cc..2d895fdbca5 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -1136,6 +1136,87 @@ describe("McpHub", () => { expect(servers[0].name).toBe("enabled-server") }) + it("should deduplicate servers by name with project servers taking priority", () => { + const mockConnections: McpConnection[] = [ + { + type: "connected", + server: { + name: "shared-server", + config: '{"source":"global"}', + status: "connected", + disabled: false, + source: "global", + }, + client: {} as any, + transport: {} as any, + } as ConnectedMcpConnection, + { + type: "connected", + server: { + name: "shared-server", + config: '{"source":"project"}', + status: "connected", + disabled: false, + source: "project", + }, + client: {} as any, + transport: {} as any, + } as ConnectedMcpConnection, + { + type: "connected", + server: { + name: "unique-global-server", + config: "{}", + status: "connected", + disabled: false, + source: "global", + }, + client: {} as any, + transport: {} as any, + } as ConnectedMcpConnection, + ] + + mcpHub.connections = mockConnections + const servers = mcpHub.getServers() + + // Should have 2 servers: deduplicated "shared-server" + "unique-global-server" + expect(servers.length).toBe(2) + + // Find the shared-server - it should be the project version + const sharedServer = servers.find((s) => s.name === "shared-server") + expect(sharedServer).toBeDefined() + expect(sharedServer!.source).toBe("project") + expect(sharedServer!.config).toBe('{"source":"project"}') + + // The unique global server should also be present + const uniqueServer = servers.find((s) => s.name === "unique-global-server") + expect(uniqueServer).toBeDefined() + }) + + it("should keep global server when no project server with same name exists", () => { + const mockConnections: McpConnection[] = [ + { + type: "connected", + server: { + name: "global-only-server", + config: "{}", + status: "connected", + disabled: false, + source: "global", + }, + client: {} as any, + transport: {} as any, + } as ConnectedMcpConnection, + ] + + mcpHub.connections = mockConnections + const servers = mcpHub.getServers() + + expect(servers.length).toBe(1) + expect(servers[0].name).toBe("global-only-server") + expect(servers[0].source).toBe("global") + }) + it("should prevent calling tools on disabled servers", async () => { // Mock fs.readFile to return a disabled server config vi.mocked(fs.readFile).mockResolvedValue(