From a2a8bfcc0e1e0dd71da229d5138f9b5770734160 Mon Sep 17 00:00:00 2001 From: Roo Code Date: Tue, 22 Jul 2025 22:28:05 +0000 Subject: [PATCH 1/6] fix: prevent disabled MCP servers from starting processes and show correct status - Backend: Skip connecting to disabled servers in connectToServer() - Backend: Handle enable/disable state changes properly in toggleServerDisabled() - Backend: Only setup file watchers for enabled servers - Frontend: Show grey status indicator for disabled servers - Frontend: Hide error messages and retry buttons for disabled servers - Frontend: Prevent expansion of disabled server rows Fixes #6036 --- src/services/mcp/McpHub.ts | 45 ++- webview-ui/src/components/mcp/McpView.tsx | 342 ++++++++++++---------- 2 files changed, 225 insertions(+), 162 deletions(-) diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 10a74712ef..4abae7874a 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -497,6 +497,7 @@ export class McpHub { const result = McpSettingsSchema.safeParse(config) if (result.success) { + // Pass all servers including disabled ones - they'll be handled in updateServerConnections await this.updateServerConnections(result.data.mcpServers || {}, source, false) } else { const errorMessages = result.error.errors @@ -560,6 +561,26 @@ export class McpHub { // Remove existing connection if it exists with the same source await this.deleteConnection(name, source) + // Skip connecting to disabled servers + if (config.disabled) { + // Still create a connection object to track the server, but don't actually connect + const connection: McpConnection = { + server: { + name, + config: JSON.stringify(config), + status: "disconnected", + disabled: true, + source, + projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined, + errorHistory: [], + }, + client: null as any, // We won't actually create a client for disabled servers + transport: null as any, // We won't actually create a transport for disabled servers + } + this.connections.push(connection) + return + } + try { const client = new Client( { @@ -975,7 +996,10 @@ export class McpHub { if (!currentConnection) { // New server try { - this.setupFileWatcher(name, validatedConfig, source) + // Only setup file watcher for enabled servers + if (!validatedConfig.disabled) { + this.setupFileWatcher(name, validatedConfig, source) + } await this.connectToServer(name, validatedConfig, source) } catch (error) { this.showErrorMessage(`Failed to connect to new MCP server ${name}`, error) @@ -983,7 +1007,10 @@ export class McpHub { } else if (!deepEqual(JSON.parse(currentConnection.server.config), config)) { // Existing server with changed config try { - this.setupFileWatcher(name, validatedConfig, source) + // Only setup file watcher for enabled servers + if (!validatedConfig.disabled) { + this.setupFileWatcher(name, validatedConfig, source) + } await this.deleteConnection(name, source) await this.connectToServer(name, validatedConfig, source) } catch (error) { @@ -1257,8 +1284,18 @@ export class McpHub { try { connection.server.disabled = disabled - // Only refresh capabilities if connected - if (connection.server.status === "connected") { + // If disabling a connected server, disconnect it + if (disabled && connection.server.status === "connected") { + await this.deleteConnection(serverName, serverSource) + // Re-add as a disabled connection + await this.connectToServer(serverName, JSON.parse(connection.server.config), serverSource) + } else if (!disabled && connection.server.status === "disconnected") { + // If enabling a disabled server, connect it + const config = JSON.parse(connection.server.config) + await this.deleteConnection(serverName, serverSource) + await this.connectToServer(serverName, config, serverSource) + } else if (connection.server.status === "connected") { + // Only refresh capabilities if connected connection.server.tools = await this.fetchToolsList(serverName, serverSource) connection.server.resources = await this.fetchResourcesList(serverName, serverSource) connection.server.resourceTemplates = await this.fetchResourceTemplatesList( diff --git a/webview-ui/src/components/mcp/McpView.tsx b/webview-ui/src/components/mcp/McpView.tsx index 0873bde195..39a537f222 100644 --- a/webview-ui/src/components/mcp/McpView.tsx +++ b/webview-ui/src/components/mcp/McpView.tsx @@ -218,6 +218,11 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM ] const getStatusColor = () => { + // Disabled servers should always show grey regardless of connection status + if (server.disabled) { + return "var(--vscode-descriptionForeground)" + } + switch (server.status) { case "connected": return "var(--vscode-testing-iconPassed)" @@ -229,7 +234,8 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM } const handleRowClick = () => { - if (server.status === "connected") { + // Only allow expansion for connected and enabled servers + if (server.status === "connected" && !server.disabled) { setIsExpanded(!isExpanded) } } @@ -270,12 +276,13 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM alignItems: "center", padding: "8px", background: "var(--vscode-textCodeBlock-background)", - cursor: server.status === "connected" ? "pointer" : "default", - borderRadius: isExpanded || server.status === "connected" ? "4px" : "4px 4px 0 0", + cursor: server.status === "connected" && !server.disabled ? "pointer" : "default", + borderRadius: + isExpanded || (server.status === "connected" && !server.disabled) ? "4px" : "4px 4px 0 0", opacity: server.disabled ? 0.6 : 1, }} onClick={handleRowClick}> - {server.status === "connected" && ( + {server.status === "connected" && !server.disabled && ( - {server.status === "connected" ? ( - isExpanded && ( -
- - - {t("mcp:tabs.tools")} ({server.tools?.length || 0}) - - - {t("mcp:tabs.resources")} ( - {[...(server.resourceTemplates || []), ...(server.resources || [])].length || 0}) - - {server.instructions && ( - {t("mcp:instructions")} - )} - - {t("mcp:tabs.errors")} ({server.errorHistory?.length || 0}) - - - - {server.tools && server.tools.length > 0 ? ( -
- {server.tools.map((tool) => ( - - ))} -
- ) : ( -
- {t("mcp:emptyState.noTools")} -
+ {server.status === "connected" && !server.disabled + ? isExpanded && ( +
+ + + {t("mcp:tabs.tools")} ({server.tools?.length || 0}) + + + {t("mcp:tabs.resources")} ( + {[...(server.resourceTemplates || []), ...(server.resources || [])].length || 0}) + + {server.instructions && ( + {t("mcp:instructions")} )} - + + {t("mcp:tabs.errors")} ({server.errorHistory?.length || 0}) + - - {(server.resources && server.resources.length > 0) || - (server.resourceTemplates && server.resourceTemplates.length > 0) ? ( -
- {[...(server.resourceTemplates || []), ...(server.resources || [])].map( - (item) => ( - + {server.tools && server.tools.length > 0 ? ( +
+ {server.tools.map((tool) => ( + - ), - )} -
- ) : ( -
- {t("mcp:emptyState.noResources")} -
- )} - + ))} +
+ ) : ( +
+ {t("mcp:emptyState.noTools")} +
+ )} +
- {server.instructions && ( - -
-
- {server.instructions} + + {(server.resources && server.resources.length > 0) || + (server.resourceTemplates && server.resourceTemplates.length > 0) ? ( +
+ {[...(server.resourceTemplates || []), ...(server.resources || [])].map( + (item) => ( + + ), + )} +
+ ) : ( +
+ {t("mcp:emptyState.noResources")}
-
+ )} - )} - - {server.errorHistory && server.errorHistory.length > 0 ? ( -
- {[...server.errorHistory] - .sort((a, b) => b.timestamp - a.timestamp) - .map((error, index) => ( - - ))} -
- ) : ( -
- {t("mcp:emptyState.noErrors")} -
+ {server.instructions && ( + +
+
+ {server.instructions} +
+
+
)} -
- - {/* Network Timeout */} -
+ + {server.errorHistory && server.errorHistory.length > 0 ? ( +
+ {[...server.errorHistory] + .sort((a, b) => b.timestamp - a.timestamp) + .map((error, index) => ( + + ))} +
+ ) : ( +
+ {t("mcp:emptyState.noErrors")} +
+ )} +
+ + + {/* Network Timeout */} +
+
+ {t("mcp:networkTimeout.label")} + +
+ + {t("mcp:networkTimeout.description")} + +
+
+ ) + : // Only show error UI for non-disabled servers + !server.disabled && ( +
- {t("mcp:networkTimeout.label")} -
- - {t("mcp:networkTimeout.description")} - + + {server.status === "connecting" + ? t("mcp:serverStatus.retrying") + : t("mcp:serverStatus.retryConnection")} +
-
- ) - ) : ( -
-
- {server.error && - server.error.split("\n").map((item, index) => ( - - {index > 0 &&
} - {item} -
- ))} -
- - {server.status === "connecting" - ? t("mcp:serverStatus.retrying") - : t("mcp:serverStatus.retryConnection")} - -
- )} + )} {/* Delete Confirmation Dialog */} From 9ee690e93ca2bfd7bb56ecdcd5f9a02bf16955b0 Mon Sep 17 00:00:00 2001 From: hannesrudolph Date: Tue, 22 Jul 2025 23:14:57 -0600 Subject: [PATCH 2/6] fix: prevent MCP servers from starting when MCP is globally disabled - Added check for global mcpEnabled state in connectToServer method - Updated refreshAllConnections to respect global MCP setting - Updated restartConnection to check global MCP state - Added tests to verify servers don't start when MCP is disabled - Ensures disabled servers show correct status in UI --- src/services/mcp/McpHub.ts | 61 +++++++++++ src/services/mcp/__tests__/McpHub.spec.ts | 117 ++++++++++++++++++++++ 2 files changed, 178 insertions(+) diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index 4abae7874a..b5229b6eaf 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -561,6 +561,34 @@ export class McpHub { // Remove existing connection if it exists with the same source await this.deleteConnection(name, source) + // Check if MCP is globally enabled + const provider = this.providerRef.deref() + if (provider) { + const state = await provider.getState() + const mcpEnabled = state.mcpEnabled ?? true + + // Skip connecting if MCP is globally disabled + if (!mcpEnabled) { + // Still create a connection object to track the server, but don't actually connect + const connection: McpConnection = { + server: { + name, + config: JSON.stringify(config), + status: "disconnected", + disabled: config.disabled, + source, + projectPath: + source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined, + errorHistory: [], + }, + client: null as any, // We won't actually create a client when MCP is disabled + transport: null as any, // We won't actually create a transport when MCP is disabled + } + this.connections.push(connection) + return + } + } + // Skip connecting to disabled servers if (config.disabled) { // Still create a connection object to track the server, but don't actually connect @@ -1100,6 +1128,16 @@ export class McpHub { return } + // Check if MCP is globally enabled + const state = await provider.getState() + const mcpEnabled = state.mcpEnabled ?? true + + // Skip restarting if MCP is globally disabled + if (!mcpEnabled) { + this.isConnecting = false + return + } + // Get existing connection and update its status const connection = this.findConnection(serverName, source) const config = connection?.server.config @@ -1138,6 +1176,29 @@ export class McpHub { return } + // Check if MCP is globally enabled + const provider = this.providerRef.deref() + if (provider) { + const state = await provider.getState() + const mcpEnabled = state.mcpEnabled ?? true + + // Skip refreshing if MCP is globally disabled + if (!mcpEnabled) { + // Clear all existing connections + const existingConnections = [...this.connections] + for (const conn of existingConnections) { + await this.deleteConnection(conn.server.name, conn.server.source) + } + + // Still initialize servers to track them, but they won't connect + await this.initializeMcpServers("global") + await this.initializeMcpServers("project") + + await this.notifyWebviewOfServerChanges() + return + } + } + this.isConnecting = true vscode.window.showInformationMessage(t("mcp:info.refreshing_all")) diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 7dc7f00c04..f3f1640c4e 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -108,6 +108,7 @@ describe("McpHub", () => { ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: true }), context: { subscriptions: [], workspaceState: {} as any, @@ -877,6 +878,122 @@ describe("McpHub", () => { }) }) + describe("MCP global enable/disable", () => { + beforeEach(() => { + // Clear all mocks before each test + vi.clearAllMocks() + }) + + it("should not connect to servers when MCP is globally disabled", async () => { + // Mock provider with mcpEnabled: false + const disabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: false }), + context: mockProvider.context, + } + + // Mock the config file read with a different server name to avoid conflicts + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create a new McpHub instance with disabled MCP + const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the disabled-test-server + const disabledServer = mcpHub.connections.find((conn) => conn.server.name === "disabled-test-server") + + // Verify that the server is tracked but not connected + expect(disabledServer).toBeDefined() + expect(disabledServer!.server.status).toBe("disconnected") + expect(disabledServer!.client).toBeNull() + expect(disabledServer!.transport).toBeNull() + }) + + it("should connect to servers when MCP is globally enabled", async () => { + // Clear all mocks + vi.clearAllMocks() + + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + Client.mockImplementation(() => ({ + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + })) + + // Mock provider with mcpEnabled: true + const enabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: true }), + context: mockProvider.context, + } + + // Mock the config file read with a different server name + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "enabled-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create a new McpHub instance with enabled MCP + const mcpHub = new McpHub(enabledMockProvider as unknown as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the enabled-test-server + const enabledServer = mcpHub.connections.find((conn) => conn.server.name === "enabled-test-server") + + // Verify that the server is connected + expect(enabledServer).toBeDefined() + expect(enabledServer!.server.status).toBe("connected") + expect(enabledServer!.client).toBeDefined() + expect(enabledServer!.transport).toBeDefined() + + // Verify StdioClientTransport was called + expect(StdioClientTransport).toHaveBeenCalled() + }) + }) + describe("Windows command wrapping", () => { let StdioClientTransport: ReturnType let Client: ReturnType From 59e9f87d95ba0d202c366e426f27e907c53ac7bb Mon Sep 17 00:00:00 2001 From: hannesrudolph Date: Tue, 22 Jul 2025 23:19:15 -0600 Subject: [PATCH 3/6] fix: disconnect MCP servers immediately when MCP is disabled - Modified webviewMessageHandler to actively disconnect all running MCP servers when the 'Enable MCP Servers' checkbox is unchecked - Added logic to reconnect servers when MCP is re-enabled - This ensures servers are shut down immediately without requiring a plugin reload --- src/core/webview/webviewMessageHandler.ts | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 780d40df89..993ad5dc71 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -880,6 +880,23 @@ export const webviewMessageHandler = async ( case "mcpEnabled": const mcpEnabled = message.bool ?? true await updateGlobalState("mcpEnabled", mcpEnabled) + + // If MCP is being disabled, disconnect all servers + const mcpHubInstance = provider.getMcpHub() + if (!mcpEnabled && mcpHubInstance) { + // Disconnect all existing connections + const existingConnections = [...mcpHubInstance.connections] + for (const conn of existingConnections) { + await mcpHubInstance.deleteConnection(conn.server.name, conn.server.source) + } + + // Re-initialize servers to track them in disconnected state + await mcpHubInstance.refreshAllConnections() + } else if (mcpEnabled && mcpHubInstance) { + // If MCP is being enabled, reconnect all servers + await mcpHubInstance.refreshAllConnections() + } + await provider.postStateToWebview() break case "enableMcpServerCreation": From 78b784b672a66843c0225c67e62145adc5993b73 Mon Sep 17 00:00:00 2001 From: hannesrudolph Date: Tue, 22 Jul 2025 23:39:59 -0600 Subject: [PATCH 4/6] fix: address PR #6084 review feedback - Extract duplicate placeholder connection creation logic into createPlaceholderConnection helper - Fix type safety by allowing null client/transport in McpConnection type - Add proper null checks throughout the codebase - Add error handling to server disconnection loop in webviewMessageHandler - Centralize MCP enabled state checking with isMcpEnabled helper - Extract repeated UI condition into isExpandable computed property - Add comprehensive test coverage for edge cases including: - Toggling global MCP enabled state while servers are active - Handling refreshAllConnections when MCP is disabled - Skipping connection restart when MCP is disabled --- src/core/webview/webviewMessageHandler.ts | 45 +++++- src/services/mcp/McpHub.ts | 150 +++++++++--------- src/services/mcp/__tests__/McpHub.spec.ts | 177 +++++++++++++++++++++- webview-ui/src/components/mcp/McpView.tsx | 14 +- 4 files changed, 298 insertions(+), 88 deletions(-) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 993ad5dc71..2ee2ee3485 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -884,17 +884,54 @@ export const webviewMessageHandler = async ( // If MCP is being disabled, disconnect all servers const mcpHubInstance = provider.getMcpHub() if (!mcpEnabled && mcpHubInstance) { - // Disconnect all existing connections + // Disconnect all existing connections with error handling const existingConnections = [...mcpHubInstance.connections] + const disconnectionErrors: Array<{ serverName: string; error: string }> = [] + for (const conn of existingConnections) { - await mcpHubInstance.deleteConnection(conn.server.name, conn.server.source) + try { + await mcpHubInstance.deleteConnection(conn.server.name, conn.server.source) + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + disconnectionErrors.push({ + serverName: conn.server.name, + error: errorMessage, + }) + provider.log(`Failed to disconnect MCP server ${conn.server.name}: ${errorMessage}`) + } + } + + // If there were errors, notify the user + if (disconnectionErrors.length > 0) { + const errorSummary = disconnectionErrors.map((e) => `${e.serverName}: ${e.error}`).join("\n") + vscode.window.showWarningMessage( + t("mcp:errors.disconnect_servers_partial", { + count: disconnectionErrors.length, + errors: errorSummary, + }) || + `Failed to disconnect ${disconnectionErrors.length} MCP server(s). Check the output for details.`, + ) } // Re-initialize servers to track them in disconnected state - await mcpHubInstance.refreshAllConnections() + try { + await mcpHubInstance.refreshAllConnections() + } catch (error) { + provider.log(`Failed to refresh MCP connections after disabling: ${error}`) + vscode.window.showErrorMessage( + t("mcp:errors.refresh_after_disable") || "Failed to refresh MCP connections after disabling", + ) + } } else if (mcpEnabled && mcpHubInstance) { // If MCP is being enabled, reconnect all servers - await mcpHubInstance.refreshAllConnections() + try { + await mcpHubInstance.refreshAllConnections() + } catch (error) { + provider.log(`Failed to refresh MCP connections after enabling: ${error}`) + vscode.window.showErrorMessage( + t("mcp:errors.refresh_after_enable") || "Failed to refresh MCP connections after enabling", + ) + } } await provider.postStateToWebview() diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index b5229b6eaf..cb6783fb7a 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -35,8 +35,8 @@ import { injectVariables } from "../../utils/config" export type McpConnection = { server: McpServer - client: Client - transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport + client: Client | null + transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport | null } // Base configuration schema for common settings @@ -553,6 +553,48 @@ export class McpHub { await this.initializeMcpServers("project") } + /** + * Creates a placeholder connection for disabled servers or when MCP is globally disabled + * @param name The server name + * @param config The server configuration + * @param source The source of the server (global or project) + * @param reason The reason for creating a placeholder (mcpDisabled or serverDisabled) + * @returns A placeholder McpConnection object + */ + private createPlaceholderConnection( + name: string, + config: z.infer, + source: "global" | "project", + reason: "mcpDisabled" | "serverDisabled", + ): McpConnection { + return { + server: { + name, + config: JSON.stringify(config), + status: "disconnected", + disabled: reason === "serverDisabled" ? true : config.disabled, + source, + projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined, + errorHistory: [], + }, + client: null, + transport: null, + } + } + + /** + * Checks if MCP is globally enabled + * @returns Promise indicating if MCP is enabled + */ + private async isMcpEnabled(): Promise { + const provider = this.providerRef.deref() + if (!provider) { + return true // Default to enabled if provider is not available + } + const state = await provider.getState() + return state.mcpEnabled ?? true + } + private async connectToServer( name: string, config: z.infer, @@ -562,49 +604,18 @@ export class McpHub { await this.deleteConnection(name, source) // Check if MCP is globally enabled - const provider = this.providerRef.deref() - if (provider) { - const state = await provider.getState() - const mcpEnabled = state.mcpEnabled ?? true - - // Skip connecting if MCP is globally disabled - if (!mcpEnabled) { - // Still create a connection object to track the server, but don't actually connect - const connection: McpConnection = { - server: { - name, - config: JSON.stringify(config), - status: "disconnected", - disabled: config.disabled, - source, - projectPath: - source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined, - errorHistory: [], - }, - client: null as any, // We won't actually create a client when MCP is disabled - transport: null as any, // We won't actually create a transport when MCP is disabled - } - this.connections.push(connection) - return - } + const mcpEnabled = await this.isMcpEnabled() + if (!mcpEnabled) { + // Still create a connection object to track the server, but don't actually connect + const connection = this.createPlaceholderConnection(name, config, source, "mcpDisabled") + this.connections.push(connection) + return } // Skip connecting to disabled servers if (config.disabled) { // Still create a connection object to track the server, but don't actually connect - const connection: McpConnection = { - server: { - name, - config: JSON.stringify(config), - status: "disconnected", - disabled: true, - source, - projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined, - errorHistory: [], - }, - client: null as any, // We won't actually create a client for disabled servers - transport: null as any, // We won't actually create a transport for disabled servers - } + const connection = this.createPlaceholderConnection(name, config, source, "serverDisabled") this.connections.push(connection) return } @@ -875,8 +886,8 @@ export class McpHub { // Use the helper method to find the connection const connection = this.findConnection(serverName, source) - if (!connection) { - throw new Error(`Server ${serverName} not found`) + if (!connection || !connection.client) { + throw new Error(`Server ${serverName} not found or not connected`) } const response = await connection.client.request({ method: "tools/list" }, ListToolsResultSchema) @@ -930,7 +941,7 @@ export class McpHub { private async fetchResourcesList(serverName: string, source?: "global" | "project"): Promise { try { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || !connection.client) { return [] } const response = await connection.client.request({ method: "resources/list" }, ListResourcesResultSchema) @@ -947,7 +958,7 @@ export class McpHub { ): Promise { try { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || !connection.client) { return [] } const response = await connection.client.request( @@ -969,8 +980,12 @@ export class McpHub { for (const connection of connections) { try { - await connection.transport.close() - await connection.client.close() + if (connection.transport) { + await connection.transport.close() + } + if (connection.client) { + await connection.client.close() + } } catch (error) { console.error(`Failed to close transport for ${name}:`, error) } @@ -1123,16 +1138,9 @@ export class McpHub { async restartConnection(serverName: string, source?: "global" | "project"): Promise { this.isConnecting = true - const provider = this.providerRef.deref() - if (!provider) { - return - } // Check if MCP is globally enabled - const state = await provider.getState() - const mcpEnabled = state.mcpEnabled ?? true - - // Skip restarting if MCP is globally disabled + const mcpEnabled = await this.isMcpEnabled() if (!mcpEnabled) { this.isConnecting = false return @@ -1177,26 +1185,20 @@ export class McpHub { } // Check if MCP is globally enabled - const provider = this.providerRef.deref() - if (provider) { - const state = await provider.getState() - const mcpEnabled = state.mcpEnabled ?? true - - // Skip refreshing if MCP is globally disabled - if (!mcpEnabled) { - // Clear all existing connections - const existingConnections = [...this.connections] - for (const conn of existingConnections) { - await this.deleteConnection(conn.server.name, conn.server.source) - } + const mcpEnabled = await this.isMcpEnabled() + if (!mcpEnabled) { + // Clear all existing connections + const existingConnections = [...this.connections] + for (const conn of existingConnections) { + await this.deleteConnection(conn.server.name, conn.server.source) + } - // Still initialize servers to track them, but they won't connect - await this.initializeMcpServers("global") - await this.initializeMcpServers("project") + // Still initialize servers to track them, but they won't connect + await this.initializeMcpServers("global") + await this.initializeMcpServers("project") - await this.notifyWebviewOfServerChanges() - return - } + await this.notifyWebviewOfServerChanges() + return } this.isConnecting = true @@ -1537,7 +1539,7 @@ export class McpHub { async readResource(serverName: string, uri: string, source?: "global" | "project"): Promise { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || !connection.client) { throw new Error(`No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}`) } if (connection.server.disabled) { @@ -1561,7 +1563,7 @@ export class McpHub { source?: "global" | "project", ): Promise { const connection = this.findConnection(serverName, source) - if (!connection) { + if (!connection || !connection.client) { throw new Error( `No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`, ) diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index f3f1640c4e..737e7a2907 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -596,7 +596,7 @@ describe("McpHub", () => { await mcpHub.callTool("test-server", "some-tool", {}) // Verify the request was made with correct parameters - expect(mockConnection.client.request).toHaveBeenCalledWith( + expect(mockConnection.client!.request).toHaveBeenCalledWith( { method: "tools/call", params: { @@ -653,7 +653,7 @@ describe("McpHub", () => { mcpHub.connections = [mockConnection] await mcpHub.callTool("test-server", "test-tool") - expect(mockConnection.client.request).toHaveBeenCalledWith( + expect(mockConnection.client!.request).toHaveBeenCalledWith( expect.anything(), expect.anything(), expect.objectContaining({ timeout: 60000 }), // 60 seconds in milliseconds @@ -676,7 +676,7 @@ describe("McpHub", () => { mcpHub.connections = [mockConnection] await mcpHub.callTool("test-server", "test-tool") - expect(mockConnection.client.request).toHaveBeenCalledWith( + expect(mockConnection.client!.request).toHaveBeenCalledWith( expect.anything(), expect.anything(), expect.objectContaining({ timeout: 120000 }), // 120 seconds in milliseconds @@ -792,7 +792,7 @@ describe("McpHub", () => { await mcpHub.callTool("test-server", "test-tool") // Verify default timeout was used - expect(mockConnectionInvalid.client.request).toHaveBeenCalledWith( + expect(mockConnectionInvalid.client!.request).toHaveBeenCalledWith( expect.anything(), expect.anything(), expect.objectContaining({ timeout: 60000 }), // Default 60 seconds @@ -884,6 +884,85 @@ describe("McpHub", () => { vi.clearAllMocks() }) + it("should disconnect all servers when MCP is toggled from enabled to disabled", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Start with MCP enabled + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: true }) + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "toggle-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub and let it initialize with MCP enabled + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify server is connected + const connectedServer = mcpHub.connections.find((conn) => conn.server.name === "toggle-test-server") + expect(connectedServer).toBeDefined() + expect(connectedServer!.server.status).toBe("connected") + expect(connectedServer!.client).toBeDefined() + expect(connectedServer!.transport).toBeDefined() + + // Now simulate toggling MCP to disabled + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false }) + + // Manually trigger what would happen when MCP is disabled + // (normally this would be triggered by the webview message handler) + const existingConnections = [...mcpHub.connections] + for (const conn of existingConnections) { + await mcpHub.deleteConnection(conn.server.name, conn.server.source) + } + await mcpHub.refreshAllConnections() + + // Verify server is now tracked but disconnected + const disconnectedServer = mcpHub.connections.find((conn) => conn.server.name === "toggle-test-server") + expect(disconnectedServer).toBeDefined() + expect(disconnectedServer!.server.status).toBe("disconnected") + expect(disconnectedServer!.client).toBeNull() + expect(disconnectedServer!.transport).toBeNull() + + // Verify close was called on the original client and transport + expect(mockClient.close).toHaveBeenCalled() + expect(mockTransport.close).toHaveBeenCalled() + }) + it("should not connect to servers when MCP is globally disabled", async () => { // Mock provider with mcpEnabled: false const disabledMockProvider = { @@ -992,6 +1071,96 @@ describe("McpHub", () => { // Verify StdioClientTransport was called expect(StdioClientTransport).toHaveBeenCalled() }) + + it("should handle refreshAllConnections when MCP is disabled", async () => { + // Mock provider with mcpEnabled: false + const disabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: false }), + context: mockProvider.context, + } + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "refresh-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub with disabled MCP + const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Clear previous calls + vi.clearAllMocks() + + // Call refreshAllConnections + await mcpHub.refreshAllConnections() + + // Verify that servers are tracked but not connected + const server = mcpHub.connections.find((conn) => conn.server.name === "refresh-test-server") + expect(server).toBeDefined() + expect(server!.server.status).toBe("disconnected") + expect(server!.client).toBeNull() + expect(server!.transport).toBeNull() + + // Verify postMessageToWebview was called to update the UI + expect(disabledMockProvider.postMessageToWebview).toHaveBeenCalledWith( + expect.objectContaining({ + type: "mcpServers", + }), + ) + }) + + it("should skip restarting connection when MCP is disabled", async () => { + // Mock provider with mcpEnabled: false + const disabledMockProvider = { + ensureSettingsDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + ensureMcpServersDirectoryExists: vi.fn().mockResolvedValue("/mock/settings/path"), + postMessageToWebview: vi.fn(), + getState: vi.fn().mockResolvedValue({ mcpEnabled: false }), + context: mockProvider.context, + } + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "restart-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub with disabled MCP + const mcpHub = new McpHub(disabledMockProvider as unknown as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Set isConnecting to false to ensure it's properly reset + mcpHub.isConnecting = false + + // Try to restart a connection + await mcpHub.restartConnection("restart-test-server") + + // Verify that isConnecting was reset to false + expect(mcpHub.isConnecting).toBe(false) + + // Verify that the server remains disconnected + const server = mcpHub.connections.find((conn) => conn.server.name === "restart-test-server") + expect(server).toBeDefined() + expect(server!.server.status).toBe("disconnected") + expect(server!.client).toBeNull() + expect(server!.transport).toBeNull() + }) }) describe("Windows command wrapping", () => { diff --git a/webview-ui/src/components/mcp/McpView.tsx b/webview-ui/src/components/mcp/McpView.tsx index 39a537f222..21ad1c2652 100644 --- a/webview-ui/src/components/mcp/McpView.tsx +++ b/webview-ui/src/components/mcp/McpView.tsx @@ -206,6 +206,9 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM return configTimeout ?? 60 // Default 1 minute (60 seconds) }) + // Computed property to check if server is expandable + const isExpandable = server.status === "connected" && !server.disabled + const timeoutOptions = [ { value: 15, label: t("mcp:networkTimeout.options.15seconds") }, { value: 30, label: t("mcp:networkTimeout.options.30seconds") }, @@ -235,7 +238,7 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM const handleRowClick = () => { // Only allow expansion for connected and enabled servers - if (server.status === "connected" && !server.disabled) { + if (isExpandable) { setIsExpanded(!isExpanded) } } @@ -276,13 +279,12 @@ const ServerRow = ({ server, alwaysAllowMcp }: { server: McpServer; alwaysAllowM alignItems: "center", padding: "8px", background: "var(--vscode-textCodeBlock-background)", - cursor: server.status === "connected" && !server.disabled ? "pointer" : "default", - borderRadius: - isExpanded || (server.status === "connected" && !server.disabled) ? "4px" : "4px 4px 0 0", + cursor: isExpandable ? "pointer" : "default", + borderRadius: isExpanded || isExpandable ? "4px" : "4px 4px 0 0", opacity: server.disabled ? 0.6 : 1, }} onClick={handleRowClick}> - {server.status === "connected" && !server.disabled && ( + {isExpandable && (
- {server.status === "connected" && !server.disabled + {isExpandable ? isExpanded && (
Date: Thu, 24 Jul 2025 14:54:02 -0500 Subject: [PATCH 5/6] feat: implement discriminated union types for MCP connections and enhance connection handling --- src/services/mcp/McpHub.ts | 73 ++- src/services/mcp/__tests__/McpHub.spec.ts | 732 ++++++++++++++++++++-- 2 files changed, 736 insertions(+), 69 deletions(-) diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index cb6783fb7a..f862f4a8d9 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -33,10 +33,27 @@ import { fileExistsAtPath } from "../../utils/fs" import { arePathsEqual } from "../../utils/path" import { injectVariables } from "../../utils/config" -export type McpConnection = { +// Discriminated union for connection states +export type ConnectedMcpConnection = { + type: "connected" server: McpServer - client: Client | null - transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport | null + client: Client + transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport +} + +export type DisconnectedMcpConnection = { + type: "disconnected" + server: McpServer + client: null + transport: null +} + +export type McpConnection = ConnectedMcpConnection | DisconnectedMcpConnection + +// Enum for disable reasons +export enum DisableReason { + MCP_DISABLED = "mcpDisabled", + SERVER_DISABLED = "serverDisabled", } // Base configuration schema for common settings @@ -559,20 +576,21 @@ export class McpHub { * @param config The server configuration * @param source The source of the server (global or project) * @param reason The reason for creating a placeholder (mcpDisabled or serverDisabled) - * @returns A placeholder McpConnection object + * @returns A placeholder DisconnectedMcpConnection object */ private createPlaceholderConnection( name: string, config: z.infer, source: "global" | "project", - reason: "mcpDisabled" | "serverDisabled", - ): McpConnection { + reason: DisableReason, + ): DisconnectedMcpConnection { return { + type: "disconnected", server: { name, config: JSON.stringify(config), status: "disconnected", - disabled: reason === "serverDisabled" ? true : config.disabled, + disabled: reason === DisableReason.SERVER_DISABLED ? true : config.disabled, source, projectPath: source === "project" ? vscode.workspace.workspaceFolders?.[0]?.uri.fsPath : undefined, errorHistory: [], @@ -607,7 +625,7 @@ export class McpHub { const mcpEnabled = await this.isMcpEnabled() if (!mcpEnabled) { // Still create a connection object to track the server, but don't actually connect - const connection = this.createPlaceholderConnection(name, config, source, "mcpDisabled") + const connection = this.createPlaceholderConnection(name, config, source, DisableReason.MCP_DISABLED) this.connections.push(connection) return } @@ -615,11 +633,14 @@ export class McpHub { // Skip connecting to disabled servers if (config.disabled) { // Still create a connection object to track the server, but don't actually connect - const connection = this.createPlaceholderConnection(name, config, source, "serverDisabled") + const connection = this.createPlaceholderConnection(name, config, source, DisableReason.SERVER_DISABLED) this.connections.push(connection) return } + // Set up file watchers for enabled servers + this.setupFileWatcher(name, config, source) + try { const client = new Client( { @@ -793,7 +814,9 @@ export class McpHub { transport.start = async () => {} } - const connection: McpConnection = { + // Create a connected connection + const connection: ConnectedMcpConnection = { + type: "connected", server: { name, config: JSON.stringify(configInjected), @@ -886,8 +909,8 @@ export class McpHub { // Use the helper method to find the connection const connection = this.findConnection(serverName, source) - if (!connection || !connection.client) { - throw new Error(`Server ${serverName} not found or not connected`) + if (!connection || connection.type !== "connected") { + return [] } const response = await connection.client.request({ method: "tools/list" }, ListToolsResultSchema) @@ -941,7 +964,7 @@ export class McpHub { private async fetchResourcesList(serverName: string, source?: "global" | "project"): Promise { try { const connection = this.findConnection(serverName, source) - if (!connection || !connection.client) { + if (!connection || connection.type !== "connected") { return [] } const response = await connection.client.request({ method: "resources/list" }, ListResourcesResultSchema) @@ -958,7 +981,7 @@ export class McpHub { ): Promise { try { const connection = this.findConnection(serverName, source) - if (!connection || !connection.client) { + if (!connection || connection.type !== "connected") { return [] } const response = await connection.client.request( @@ -973,6 +996,9 @@ export class McpHub { } async deleteConnection(name: string, source?: "global" | "project"): Promise { + // Clean up file watchers for this server + this.removeFileWatchersForServer(name) + // If source is provided, only delete connections from that source const connections = source ? this.connections.filter((conn) => conn.server.name === name && conn.server.source === source) @@ -980,10 +1006,8 @@ export class McpHub { for (const connection of connections) { try { - if (connection.transport) { + if (connection.type === "connected") { await connection.transport.close() - } - if (connection.client) { await connection.client.close() } } catch (error) { @@ -1136,6 +1160,14 @@ export class McpHub { this.fileWatchers.clear() } + private removeFileWatchersForServer(serverName: string) { + const watchers = this.fileWatchers.get(serverName) + if (watchers) { + watchers.forEach((watcher) => watcher.close()) + this.fileWatchers.delete(serverName) + } + } + async restartConnection(serverName: string, source?: "global" | "project"): Promise { this.isConnecting = true @@ -1349,6 +1381,8 @@ export class McpHub { // If disabling a connected server, disconnect it if (disabled && connection.server.status === "connected") { + // Clean up file watchers when disabling + this.removeFileWatchersForServer(serverName) await this.deleteConnection(serverName, serverSource) // Re-add as a disabled connection await this.connectToServer(serverName, JSON.parse(connection.server.config), serverSource) @@ -1356,6 +1390,7 @@ export class McpHub { // If enabling a disabled server, connect it const config = JSON.parse(connection.server.config) await this.deleteConnection(serverName, serverSource) + // When re-enabling, file watchers will be set up in connectToServer await this.connectToServer(serverName, config, serverSource) } else if (connection.server.status === "connected") { // Only refresh capabilities if connected @@ -1539,7 +1574,7 @@ export class McpHub { async readResource(serverName: string, uri: string, source?: "global" | "project"): Promise { const connection = this.findConnection(serverName, source) - if (!connection || !connection.client) { + if (!connection || connection.type !== "connected") { throw new Error(`No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}`) } if (connection.server.disabled) { @@ -1563,7 +1598,7 @@ export class McpHub { source?: "global" | "project", ): Promise { const connection = this.findConnection(serverName, source) - if (!connection || !connection.client) { + if (!connection || connection.type !== "connected") { throw new Error( `No connection found for server: ${serverName}${source ? ` with source ${source}` : ""}. Please make sure to use MCP servers available under 'Connected MCP Servers'.`, ) diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 737e7a2907..ebce2d5b2a 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -1,7 +1,7 @@ -import type { McpHub as McpHubType, McpConnection } from "../McpHub" +import type { McpHub as McpHubType, McpConnection, ConnectedMcpConnection, DisconnectedMcpConnection } from "../McpHub" import type { ClineProvider } from "../../../core/webview/ClineProvider" import type { ExtensionContext, Uri } from "vscode" -import { ServerConfigSchema, McpHub } from "../McpHub" +import { ServerConfigSchema, McpHub, DisableReason } from "../McpHub" import fs from "fs/promises" import { vi, Mock } from "vitest" @@ -33,11 +33,15 @@ vi.mock("fs/promises", () => ({ mkdir: vi.fn().mockResolvedValue(undefined), })) +// Import safeWriteJson to use in mocks +import { safeWriteJson } from "../../../utils/safeWriteJson" + // Mock safeWriteJson vi.mock("../../../utils/safeWriteJson", () => ({ safeWriteJson: vi.fn(async (filePath, data) => { // Instead of trying to write to the file system, just call fs.writeFile mock // This avoids the complex file locking and temp file operations + const fs = await import("fs/promises") return fs.writeFile(filePath, JSON.stringify(data), "utf8") }), })) @@ -79,6 +83,16 @@ vi.mock("@modelcontextprotocol/sdk/client/index.js", () => ({ Client: vi.fn(), })) +// Mock chokidar +vi.mock("chokidar", () => ({ + default: { + watch: vi.fn().mockReturnValue({ + on: vi.fn().mockReturnThis(), + close: vi.fn(), + }), + }, +})) + describe("McpHub", () => { let mcpHub: McpHubType let mockProvider: Partial @@ -168,6 +182,587 @@ describe("McpHub", () => { } }) + describe("Discriminated union type handling", () => { + it("should create connected connections with proper type", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Mock the config file read + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "union-test-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + // Create McpHub and let it initialize + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "union-test-server") + expect(connection).toBeDefined() + + // Type guard check - connected connections should have client and transport + if (connection && connection.type === "connected") { + expect(connection.client).toBeDefined() + expect(connection.transport).toBeDefined() + expect(connection.server.status).toBe("connected") + } else { + throw new Error("Connection should be of type 'connected'") + } + }) + + it("should create disconnected connections for disabled servers", async () => { + // Mock the config file read with a disabled server + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-union-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) + + // Create McpHub and let it initialize + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-union-server") + expect(connection).toBeDefined() + + // Type guard check - disconnected connections should have null client and transport + if (connection && connection.type === "disconnected") { + expect(connection.client).toBeNull() + expect(connection.transport).toBeNull() + expect(connection.server.status).toBe("disconnected") + expect(connection.server.disabled).toBe(true) + } else { + throw new Error("Connection should be of type 'disconnected'") + } + }) + + it("should handle type narrowing correctly in callTool", async () => { + // Mock fs.readFile to return empty config so no servers are initialized + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: {}, + }), + ) + + // Create a mock McpHub instance + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Clear any connections that might have been created + mcpHub.connections = [] + + // Directly set up a connected connection + const connectedConnection: ConnectedMcpConnection = { + type: "connected", + server: { + name: "test-server", + config: JSON.stringify({ command: "node", args: ["test.js"] }), + status: "connected", + source: "global", + errorHistory: [], + } as any, + client: { + request: vi.fn().mockResolvedValue({ result: "success" }), + } as any, + transport: {} as any, + } + + // Add the connected connection + mcpHub.connections = [connectedConnection] + + // Call tool should work with connected server + const result = await mcpHub.callTool("test-server", "test-tool", {}) + expect(result).toEqual({ result: "success" }) + expect(connectedConnection.client.request).toHaveBeenCalled() + + // Now test with a disconnected connection + const disconnectedConnection: DisconnectedMcpConnection = { + type: "disconnected", + server: { + name: "disabled-server", + config: JSON.stringify({ command: "node", args: ["test.js"], disabled: true }), + status: "disconnected", + disabled: true, + source: "global", + errorHistory: [], + } as any, + client: null, + transport: null, + } + + // Replace connections with disconnected one + mcpHub.connections = [disconnectedConnection] + + // Call tool should fail with disconnected server + await expect(mcpHub.callTool("disabled-server", "test-tool", {})).rejects.toThrow( + "No connection found for server: disabled-server", + ) + }) + }) + + describe("File watcher cleanup", () => { + it("should clean up file watchers when server is disabled", async () => { + // Get the mocked chokidar + const chokidar = (await import("chokidar")).default + const mockWatcher = { + on: vi.fn().mockReturnThis(), + close: vi.fn(), + } + vi.mocked(chokidar.watch).mockReturnValue(mockWatcher as any) + + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Create server with watchPaths + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "watcher-test-server": { + command: "node", + args: ["test.js"], + watchPaths: ["/path/to/watch"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify watcher was created + expect(chokidar.watch).toHaveBeenCalledWith(["/path/to/watch"], expect.any(Object)) + + // Now disable the server + await mcpHub.toggleServerDisabled("watcher-test-server", true) + + // Verify watcher was closed + expect(mockWatcher.close).toHaveBeenCalled() + }) + + it("should clean up all file watchers when server is deleted", async () => { + // Get the mocked chokidar + const chokidar = (await import("chokidar")).default + const mockWatcher1 = { + on: vi.fn().mockReturnThis(), + close: vi.fn(), + } + const mockWatcher2 = { + on: vi.fn().mockReturnThis(), + close: vi.fn(), + } + + // Return different watchers for different paths + let watcherIndex = 0 + vi.mocked(chokidar.watch).mockImplementation(() => { + return (watcherIndex++ === 0 ? mockWatcher1 : mockWatcher2) as any + }) + + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + // Create server with multiple watchPaths + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "multi-watcher-server": { + command: "node", + args: ["test.js", "build/index.js"], // This will create a watcher for build/index.js + watchPaths: ["/path/to/watch1", "/path/to/watch2"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify watchers were created + expect(chokidar.watch).toHaveBeenCalled() + + // Delete the connection (this should clean up all watchers) + await mcpHub.deleteConnection("multi-watcher-server") + + // Verify all watchers were closed + expect(mockWatcher1.close).toHaveBeenCalled() + expect(mockWatcher2.close).toHaveBeenCalled() + }) + + it("should not create file watchers for disabled servers on initialization", async () => { + // Get the mocked chokidar + const chokidar = (await import("chokidar")).default + + // Create disabled server with watchPaths + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-watcher-server": { + command: "node", + args: ["test.js"], + watchPaths: ["/path/to/watch"], + disabled: true, + }, + }, + }), + ) + + vi.mocked(chokidar.watch).mockClear() + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Verify no watcher was created for disabled server + expect(chokidar.watch).not.toHaveBeenCalled() + }) + }) + + describe("DisableReason enum usage", () => { + it("should use MCP_DISABLED reason when MCP is globally disabled", async () => { + // Mock provider with mcpEnabled: false + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false }) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "mcp-disabled-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "mcp-disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.status).toBe("disconnected") + + // The server should not be marked as disabled individually + expect(connection?.server.disabled).toBeUndefined() + }) + + it("should use SERVER_DISABLED reason when server is individually disabled", async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "server-disabled-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "server-disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.status).toBe("disconnected") + expect(connection?.server.disabled).toBe(true) + }) + + it("should handle both disable reasons correctly", async () => { + // First test with MCP globally disabled + mockProvider.getState = vi.fn().mockResolvedValue({ mcpEnabled: false }) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "both-reasons-server": { + command: "node", + args: ["test.js"], + disabled: true, // Server is also individually disabled + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Find the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "both-reasons-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + + // When MCP is globally disabled, it takes precedence + // The server's individual disabled state should be preserved + expect(connection?.server.disabled).toBe(true) + }) + }) + + describe("Null safety improvements", () => { + it("should handle null client safely in disconnected connections", async () => { + // Mock fs.readFile to return a disabled server config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "null-safety-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // The server should be created as a disconnected connection with null client/transport + const connection = mcpHub.connections.find((conn) => conn.server.name === "null-safety-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + + // Type guard to ensure it's a disconnected connection + if (connection?.type === "disconnected") { + expect(connection.client).toBeNull() + expect(connection.transport).toBeNull() + } + + // Try to call tool on disconnected server + await expect(mcpHub.callTool("null-safety-server", "test-tool", {})).rejects.toThrow( + "No connection found for server: null-safety-server", + ) + + // Try to read resource on disconnected server + await expect(mcpHub.readResource("null-safety-server", "test-uri")).rejects.toThrow( + "No connection found for server: null-safety-server", + ) + }) + + it("should handle connection type checks safely", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "type-check-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Get the connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "type-check-server") + expect(connection).toBeDefined() + + // Safe type checking + if (connection?.type === "connected") { + expect(connection.client).toBeDefined() + expect(connection.transport).toBeDefined() + } else if (connection?.type === "disconnected") { + expect(connection.client).toBeNull() + expect(connection.transport).toBeNull() + } + }) + + it("should handle missing connections safely", async () => { + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Try operations on non-existent server + await expect(mcpHub.callTool("non-existent-server", "test-tool", {})).rejects.toThrow( + "No connection found for server: non-existent-server", + ) + + await expect(mcpHub.readResource("non-existent-server", "test-uri")).rejects.toThrow( + "No connection found for server: non-existent-server", + ) + }) + + it("should handle connection deletion safely", async () => { + // Mock StdioClientTransport + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + stderr: { + on: vi.fn(), + }, + onerror: null, + onclose: null, + } + + StdioClientTransport.mockImplementation(() => mockTransport) + + // Mock Client + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + const Client = clientModule.Client as ReturnType + + const mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + } + + Client.mockImplementation(() => mockClient) + + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "delete-safety-server": { + command: "node", + args: ["test.js"], + }, + }, + }), + ) + + const mcpHub = new McpHub(mockProvider as ClineProvider) + await new Promise((resolve) => setTimeout(resolve, 100)) + + // Delete the connection + await mcpHub.deleteConnection("delete-safety-server") + + // Verify connection is removed + const connection = mcpHub.connections.find((conn) => conn.server.name === "delete-safety-server") + expect(connection).toBeUndefined() + + // Verify transport and client were closed + expect(mockTransport.close).toHaveBeenCalled() + expect(mockClient.close).toHaveBeenCalled() + }) + }) + describe("toggleToolAlwaysAllow", () => { it("should add tool to always allow list when enabling", async () => { const mockConfig = { @@ -185,7 +780,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection without alwaysAllow - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -233,7 +829,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -281,7 +878,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -326,7 +924,8 @@ describe("McpHub", () => { } // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: "test-server-config", @@ -373,7 +972,8 @@ describe("McpHub", () => { } // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: "test-server-config", @@ -419,7 +1019,8 @@ describe("McpHub", () => { } // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: "test-server-config", @@ -469,7 +1070,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -501,6 +1103,7 @@ describe("McpHub", () => { it("should filter out disabled servers from getServers", () => { const mockConnections: McpConnection[] = [ { + type: "connected", server: { name: "enabled-server", config: "{}", @@ -509,17 +1112,18 @@ describe("McpHub", () => { }, client: {} as any, transport: {} as any, - }, + } as ConnectedMcpConnection, { + type: "disconnected", server: { name: "disabled-server", config: "{}", - status: "connected", + status: "disconnected", disabled: true, }, - client: {} as any, - transport: {} as any, - }, + client: null, + transport: null, + } as DisconnectedMcpConnection, ] mcpHub.connections = mockConnections @@ -530,44 +1134,64 @@ describe("McpHub", () => { }) it("should prevent calling tools on disabled servers", async () => { - const mockConnection: McpConnection = { - server: { - name: "disabled-server", - config: "{}", - status: "connected", - disabled: true, - }, - client: { - request: vi.fn().mockResolvedValue({ result: "success" }), - } as any, - transport: {} as any, - } + // Mock fs.readFile to return a disabled server config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) - mcpHub.connections = [mockConnection] + const mcpHub = new McpHub(mockProvider as ClineProvider) + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // The server should be created as a disconnected connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.disabled).toBe(true) + + // Try to call tool on disabled server await expect(mcpHub.callTool("disabled-server", "some-tool", {})).rejects.toThrow( - 'Server "disabled-server" is disabled and cannot be used', + "No connection found for server: disabled-server", ) }) it("should prevent reading resources from disabled servers", async () => { - const mockConnection: McpConnection = { - server: { - name: "disabled-server", - config: "{}", - status: "connected", - disabled: true, - }, - client: { - request: vi.fn(), - } as any, - transport: {} as any, - } + // Mock fs.readFile to return a disabled server config + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "disabled-server": { + command: "node", + args: ["test.js"], + disabled: true, + }, + }, + }), + ) - mcpHub.connections = [mockConnection] + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Wait for initialization + await new Promise((resolve) => setTimeout(resolve, 100)) + + // The server should be created as a disconnected connection + const connection = mcpHub.connections.find((conn) => conn.server.name === "disabled-server") + expect(connection).toBeDefined() + expect(connection?.type).toBe("disconnected") + expect(connection?.server.disabled).toBe(true) + // Try to read resource from disabled server await expect(mcpHub.readResource("disabled-server", "some/uri")).rejects.toThrow( - 'Server "disabled-server" is disabled', + "No connection found for server: disabled-server", ) }) }) @@ -575,7 +1199,8 @@ describe("McpHub", () => { describe("callTool", () => { it("should execute tool successfully", async () => { // Mock the connection with a minimal client implementation - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({}), @@ -638,7 +1263,8 @@ describe("McpHub", () => { }) it("should use default timeout of 60 seconds if not specified", async () => { - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({ type: "stdio", command: "test" }), // No timeout specified @@ -661,7 +1287,8 @@ describe("McpHub", () => { }) it("should apply configured timeout to tool calls", async () => { - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({ type: "stdio", command: "test", timeout: 120 }), // 2 minutes @@ -701,7 +1328,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -746,7 +1374,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection before updating - const mockConnectionInitial: McpConnection = { + const mockConnectionInitial: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -769,7 +1398,8 @@ describe("McpHub", () => { expect(fs.writeFile).toHaveBeenCalled() // Setup connection with invalid timeout - const mockConnectionInvalid: McpConnection = { + const mockConnectionInvalid: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", config: JSON.stringify({ @@ -814,7 +1444,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", @@ -853,7 +1484,8 @@ describe("McpHub", () => { vi.mocked(fs.readFile).mockResolvedValueOnce(JSON.stringify(mockConfig)) // Set up mock connection - const mockConnection: McpConnection = { + const mockConnection: ConnectedMcpConnection = { + type: "connected", server: { name: "test-server", type: "stdio", From 7485066b37b4e9028799b1377a3d235cbd0536b9 Mon Sep 17 00:00:00 2001 From: Daniel Riccio Date: Thu, 24 Jul 2025 15:31:29 -0500 Subject: [PATCH 6/6] feat: delegate MCP enable/disable logic to McpHub and streamline connection management --- src/core/webview/webviewMessageHandler.ts | 53 ++------------------- src/services/mcp/McpHub.ts | 58 +++++++++++++++++++++++ 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 2ee2ee3485..016fecc97f 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -881,57 +881,10 @@ export const webviewMessageHandler = async ( const mcpEnabled = message.bool ?? true await updateGlobalState("mcpEnabled", mcpEnabled) - // If MCP is being disabled, disconnect all servers + // Delegate MCP enable/disable logic to McpHub const mcpHubInstance = provider.getMcpHub() - if (!mcpEnabled && mcpHubInstance) { - // Disconnect all existing connections with error handling - const existingConnections = [...mcpHubInstance.connections] - const disconnectionErrors: Array<{ serverName: string; error: string }> = [] - - for (const conn of existingConnections) { - try { - await mcpHubInstance.deleteConnection(conn.server.name, conn.server.source) - } catch (error) { - const errorMessage = error instanceof Error ? error.message : String(error) - disconnectionErrors.push({ - serverName: conn.server.name, - error: errorMessage, - }) - provider.log(`Failed to disconnect MCP server ${conn.server.name}: ${errorMessage}`) - } - } - - // If there were errors, notify the user - if (disconnectionErrors.length > 0) { - const errorSummary = disconnectionErrors.map((e) => `${e.serverName}: ${e.error}`).join("\n") - vscode.window.showWarningMessage( - t("mcp:errors.disconnect_servers_partial", { - count: disconnectionErrors.length, - errors: errorSummary, - }) || - `Failed to disconnect ${disconnectionErrors.length} MCP server(s). Check the output for details.`, - ) - } - - // Re-initialize servers to track them in disconnected state - try { - await mcpHubInstance.refreshAllConnections() - } catch (error) { - provider.log(`Failed to refresh MCP connections after disabling: ${error}`) - vscode.window.showErrorMessage( - t("mcp:errors.refresh_after_disable") || "Failed to refresh MCP connections after disabling", - ) - } - } else if (mcpEnabled && mcpHubInstance) { - // If MCP is being enabled, reconnect all servers - try { - await mcpHubInstance.refreshAllConnections() - } catch (error) { - provider.log(`Failed to refresh MCP connections after enabling: ${error}`) - vscode.window.showErrorMessage( - t("mcp:errors.refresh_after_enable") || "Failed to refresh MCP connections after enabling", - ) - } + if (mcpHubInstance) { + await mcpHubInstance.handleMcpEnabledChange(mcpEnabled) } await provider.postStateToWebview() diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index f862f4a8d9..6d512b3f28 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -1744,6 +1744,64 @@ export class McpHub { } } + /** + * Handles enabling/disabling MCP globally + * @param enabled Whether MCP should be enabled or disabled + * @returns Promise + */ + async handleMcpEnabledChange(enabled: boolean): Promise { + if (!enabled) { + // If MCP is being disabled, disconnect all servers with error handling + const existingConnections = [...this.connections] + const disconnectionErrors: Array<{ serverName: string; error: string }> = [] + + for (const conn of existingConnections) { + try { + await this.deleteConnection(conn.server.name, conn.server.source) + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + disconnectionErrors.push({ + serverName: conn.server.name, + error: errorMessage, + }) + console.error(`Failed to disconnect MCP server ${conn.server.name}: ${errorMessage}`) + } + } + + // If there were errors, notify the user + if (disconnectionErrors.length > 0) { + const errorSummary = disconnectionErrors.map((e) => `${e.serverName}: ${e.error}`).join("\n") + vscode.window.showWarningMessage( + t("mcp:errors.disconnect_servers_partial", { + count: disconnectionErrors.length, + errors: errorSummary, + }) || + `Failed to disconnect ${disconnectionErrors.length} MCP server(s). Check the output for details.`, + ) + } + + // Re-initialize servers to track them in disconnected state + try { + await this.refreshAllConnections() + } catch (error) { + console.error(`Failed to refresh MCP connections after disabling: ${error}`) + vscode.window.showErrorMessage( + t("mcp:errors.refresh_after_disable") || "Failed to refresh MCP connections after disabling", + ) + } + } else { + // If MCP is being enabled, reconnect all servers + try { + await this.refreshAllConnections() + } catch (error) { + console.error(`Failed to refresh MCP connections after enabling: ${error}`) + vscode.window.showErrorMessage( + t("mcp:errors.refresh_after_enable") || "Failed to refresh MCP connections after enabling", + ) + } + } + } + async dispose(): Promise { // Prevent multiple disposals if (this.isDisposed) {