diff --git a/src/client/stdio.ts b/src/client/stdio.ts index d62a3aeb6..f4a243d35 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -134,7 +134,6 @@ export class StdioClientTransport implements Transport { this._process.on('error', error => { if (error.name === 'AbortError') { // Expected when close() is called. - this.onclose?.(); return; } @@ -210,8 +209,36 @@ export class StdioClientTransport implements Transport { } async close(): Promise { - this._abortController.abort(); - this._process = undefined; + if (this._process) { + const processToClose = this._process; + this._process = undefined; + + const closePromise = new Promise(resolve => { + processToClose.once('close', () => { + resolve(); + }); + }); + + this._abortController.abort(); + + // waits the underlying process to exit cleanly otherwise after 1s kills it + await Promise.race([closePromise, new Promise(resolve => setTimeout(resolve, 1_000).unref())]); + + if (processToClose.exitCode === null) { + try { + processToClose.stdin?.end(); + } catch { + // ignore errors in trying to close stdin + } + + try { + processToClose.kill('SIGKILL'); + } catch { + // we did our best + } + } + } + this._readBuffer.clear(); } diff --git a/src/integration-tests/process-cleanup.test.ts b/src/integration-tests/process-cleanup.test.ts index e90ec7e24..1a5b9a35b 100644 --- a/src/integration-tests/process-cleanup.test.ts +++ b/src/integration-tests/process-cleanup.test.ts @@ -1,10 +1,14 @@ +import { Readable, Writable } from 'node:stream'; +import { Client } from '../client/index.js'; +import { StdioClientTransport } from '../client/stdio.js'; import { Server } from '../server/index.js'; import { StdioServerTransport } from '../server/stdio.js'; +import { LoggingMessageNotificationSchema } from '../types.js'; describe('Process cleanup', () => { vi.setConfig({ testTimeout: 5000 }); // 5 second timeout - it('should exit cleanly after closing transport', async () => { + it('server should exit cleanly after closing transport', async () => { const server = new Server( { name: 'test-server', @@ -15,14 +19,92 @@ describe('Process cleanup', () => { } ); - const transport = new StdioServerTransport(); + const mockReadable = new Readable({ + read() { + this.push(null); // signal EOF + } + }), + mockWritable = new Writable({ + write(chunk, encoding, callback) { + callback(); + } + }); + + // Attach mock streams to process for the server transport + const transport = new StdioServerTransport(mockReadable, mockWritable); await server.connect(transport); // Close the transport await transport.close(); + // ensure a proper disposal mock streams + mockReadable.destroy(); + mockWritable.destroy(); + // If we reach here without hanging, the test passes // The test runner will fail if the process hangs expect(true).toBe(true); }); + + it('onclose should be called exactly once', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StdioClientTransport({ + command: 'npm', + args: ['exec', 'tsx', 'test-server.ts'], + cwd: __dirname + }); + + await client.connect(transport); + + let onCloseWasCalled = 0; + client.onclose = () => { + onCloseWasCalled++; + }; + + await client.close(); + + // A short delay to allow the close event to propagate + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(onCloseWasCalled).toBe(1); + }); + + it('should exit cleanly for a server that hangs', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StdioClientTransport({ + command: 'npm', + args: ['exec', 'tsx', 'server-that-hangs.ts'], + cwd: __dirname + }); + + await client.connect(transport); + await client.setLoggingLevel('debug'); + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + console.debug('server log: ' + notification.params.data); + }); + const serverPid = transport.pid!; + + await client.close(); + + // A short delay to allow the close event to propagate + await new Promise(resolve => setTimeout(resolve, 50)); + + try { + process.kill(serverPid, 9); + throw new Error('Expected server to be dead but it is alive'); + } catch (err: unknown) { + // 'ESRCH' the process doesn't exist + if (err && typeof err === 'object' && 'code' in err && err.code === 'ESRCH') { + // success + } else throw err; + } + }); }); diff --git a/src/integration-tests/server-that-hangs.ts b/src/integration-tests/server-that-hangs.ts new file mode 100644 index 000000000..7f82ac3f7 --- /dev/null +++ b/src/integration-tests/server-that-hangs.ts @@ -0,0 +1,45 @@ +import { setTimeout } from 'node:timers'; +import process from 'node:process'; +import { McpServer } from '../server/mcp.js'; +import { StdioServerTransport } from '../server/stdio.js'; + +const transport = new StdioServerTransport(); + +const server = new McpServer( + { + name: 'server-that-hangs', + title: 'Test Server that hangs', + version: '1.0.0' + }, + { + capabilities: { + logging: {} + } + } +); + +await server.connect(transport); + +const doNotExitImmediately = async (signal: NodeJS.Signals) => { + await server.sendLoggingMessage({ + level: 'debug', + data: `received signal ${signal}` + }); + setTimeout(() => process.exit(0), 30 * 1000); +}; + +transport.onclose = () => { + server.sendLoggingMessage({ + level: 'debug', + data: 'transport: onclose called. This should never happen' + }); +}; + +process.stdin.on('close', hadErr => { + server.sendLoggingMessage({ + level: 'debug', + data: 'stdin closed. Error: ' + hadErr + }); +}); +process.on('SIGINT', doNotExitImmediately); +process.on('SIGTERM', doNotExitImmediately); diff --git a/src/integration-tests/test-server.ts b/src/integration-tests/test-server.ts new file mode 100644 index 000000000..6401d0f83 --- /dev/null +++ b/src/integration-tests/test-server.ts @@ -0,0 +1,19 @@ +import { McpServer } from '../server/mcp.js'; +import { StdioServerTransport } from '../server/stdio.js'; + +const transport = new StdioServerTransport(); + +const server = new McpServer({ + name: 'test-server', + version: '1.0.0' +}); + +await server.connect(transport); + +const exit = async () => { + await server.close(); + process.exit(0); +}; + +process.on('SIGINT', exit); +process.on('SIGTERM', exit); diff --git a/tsconfig.cjs.json b/tsconfig.cjs.json index 3b46f11c4..0870184d9 100644 --- a/tsconfig.cjs.json +++ b/tsconfig.cjs.json @@ -5,5 +5,5 @@ "moduleResolution": "node", "outDir": "./dist/cjs" }, - "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*", "src/integration-tests"] } diff --git a/tsconfig.prod.json b/tsconfig.prod.json index fcf2e951c..6eedfd710 100644 --- a/tsconfig.prod.json +++ b/tsconfig.prod.json @@ -3,5 +3,5 @@ "compilerOptions": { "outDir": "./dist/esm" }, - "exclude": ["**/*.test.ts", "src/__mocks__/**/*", "src/server/zodTestMatrix.ts"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*", "src/server/zodTestMatrix.ts", "src/integration-tests"] }