diff --git a/lib/srv/mcp/stdio.go b/lib/srv/mcp/stdio.go index d48727ebe8d7f..64a589a2002ea 100644 --- a/lib/srv/mcp/stdio.go +++ b/lib/srv/mcp/stdio.go @@ -214,21 +214,31 @@ func makeExecServerRunner(ctx context.Context, session *sessionHandler) (stdioSe // WaitDelay forces a SIGKILL if the process fails to exit 10 seconds after // cmd.Cancel is called. See the WaitDelay doc for details. cmd.WaitDelay = 10 * time.Second + // We put all shutdown procedures in cmd.Cancel because we are too lazy to // make a separate function. Since cmd.Cancel can be called outside here by // the server handler, we make sure 'cmdCancel' is called to cancel the // command in that case. - cmd.Cancel = sync.OnceValue(func() error { + // + // There is also a race where cmd.Cancel may be called before cmd.Process is + // available. To make sure the signal is sent, do not use sync.Once on the + // entire "cmd.Cancel". Instead, use sync.Once after cmd.Process non-nil + // check so we have a better chance to clean up. + signalOnce := sync.OnceValue(func() error { + // Use SIGINT for graceful shutdown since stdio servers are + // "interactive". + logger.DebugContext(ctx, "Sending SIGINT to command") + return trace.Wrap(cmd.Process.Signal(syscall.SIGINT)) + }) + cmd.Cancel = func() error { + logger.DebugContext(ctx, "Canceling command", "has_process", cmd.Process != nil) cmdCancel() if cmd.Process != nil { - // Use SIGINT for graceful shutdown since stdio servers are - // "interactive". - logger.DebugContext(ctx, "Sending SIGINT to command") - return trace.Wrap(cmd.Process.Signal(syscall.SIGINT)) + return trace.Wrap(signalOnce()) } return nil - }) + } // Set host user. 
hostUser, err := user.Lookup(mcpSpec.RunAsHostUser) diff --git a/lib/srv/mcp/stdio_test.go b/lib/srv/mcp/stdio_test.go index 6a90af1e09a03..4fe5e6973c6d3 100644 --- a/lib/srv/mcp/stdio_test.go +++ b/lib/srv/mcp/stdio_test.go @@ -20,6 +20,7 @@ package mcp import ( "context" + "math/rand/v2" "os" "path" "testing" @@ -213,6 +214,30 @@ func TestHandleSession_execMCPServer(t *testing.T) { afterHandlerStart: connectAfterHandlerStart, afterHandlerStop: containerShouldBeRemoved, }, + { + // Randomly cancel the context to simulate a case where client + // disconnects while cmd is being set up, which may cause a race + // condition that leaves docker container behind: + // https://github.com/gravitational/teleport/issues/59768 + // + // To reproduce the bug, use `sync.OnceValue` when creating the func for `cmd.Cancel`: + // cmd.Cancel = sync.OnceValue(func() error { + // + // Run this test with -count=100 for better coverage. + name: "random cancel handler context", + cmd: "docker", + dockerRunArgs: []string{"mcp/everything"}, + checkHandlerError: func(require.TestingT, error, ...interface{}) { + // Depending on the timing, this can return an error or nil, so just ignore it. + }, + cancelHandlerCtx: true, + waitForHandlerExit: time.Second * 15, + afterHandlerStart: func(t *testing.T, testCtx *testContext, containerName string) { + time.Sleep(time.Duration(rand.Uint32N(10000)) * time.Microsecond) + }, + // Make sure the container is removed no matter the timing. + afterHandlerStop: containerShouldBeRemoved, + }, { // Make sure handler is not blocked when command fails to start. name: "fail to start", @@ -246,12 +271,16 @@ func TestHandleSession_execMCPServer(t *testing.T) { `trap "" INT; while :; do sleep 1; done`, }, checkHandlerError: require.Error, - afterHandlerStart: func(t *testing.T, testCtx *testContext, _ string) { - // Trigger shutdown. 
+ afterHandlerStart: func(t *testing.T, testCtx *testContext, containerName string) { + ctx := t.Context() + t.Log("waiting for docker container to spawn before killing client connection") + require.EventuallyWithT(t, func(t *assert.CollectT) { + require.NotEmpty(t, findDockerContainerID(ctx, dockerClient, containerName)) + }, time.Second*5, time.Millisecond*100) testCtx.clientSourceConn.Close() t.Log("waiting 10 seconds for SIGKILL") }, - waitForHandlerExit: time.Second * 15, + waitForHandlerExit: time.Second * 20, }, } @@ -281,12 +310,9 @@ func TestHandleSession_execMCPServer(t *testing.T) { testCtx := setupTestContext(t, withAdminRole(t), withApp(app)) handlerCtx, handlerCtxCancel := context.WithCancel(t.Context()) defer handlerCtxCancel() - handlerDoneCh := make(chan struct{}, 1) - defer close(handlerDoneCh) + handlerErrChan := make(chan error, 1) go func() { - handlerErr := s.HandleSession(handlerCtx, testCtx.SessionCtx) - handlerDoneCh <- struct{}{} - tt.checkHandlerError(t, handlerErr) + handlerErrChan <- s.HandleSession(handlerCtx, testCtx.SessionCtx) }() if tt.afterHandlerStart != nil { @@ -299,7 +325,8 @@ func TestHandleSession_execMCPServer(t *testing.T) { select { case <-time.After(tt.waitForHandlerExit): require.Fail(t, "timed out waiting for handler") - case <-handlerDoneCh: + case handlerErr := <-handlerErrChan: + tt.checkHandlerError(t, handlerErr) } if tt.afterHandlerStop != nil {