Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions lib/srv/mcp/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,21 +214,31 @@ func makeExecServerRunner(ctx context.Context, session *sessionHandler) (stdioSe
// WaitDelay forces a SIGKILL if the process fails to exit 10 seconds after
// cmd.Cancel is called. See the WaitDelay doc for details.
cmd.WaitDelay = 10 * time.Second

// We put all shutdown procedures in cmd.Cancel because we are too lazy to
// make a separate function. Since cmd.Cancel can be called outside here by
// the server handler, we make sure 'cmdCancel' is called to cancel the
// command in that case.
cmd.Cancel = sync.OnceValue(func() error {
//
// There is also a race where cmd.Cancel may be called before cmd.Process is
// available. To make sure the signal is sent, do not use sync.Once on the
// entire "cmd.Cancel". Instead, use sync.Once after cmd.Process non-nil
// check so we have a better chance to clean up.
signalOnce := sync.OnceValue(func() error {
// Use SIGINT for graceful shutdown since stdio servers are
Comment thread
Tener marked this conversation as resolved.
// "interactive".
logger.DebugContext(ctx, "Sending SIGINT to command")
return trace.Wrap(cmd.Process.Signal(syscall.SIGINT))
})
cmd.Cancel = func() error {
logger.DebugContext(ctx, "Canceling command", "has_process", cmd.Process == nil)
cmdCancel()

if cmd.Process != nil {
// Use SIGINT for graceful shutdown since stdio servers are
// "interactive".
logger.DebugContext(ctx, "Sending SIGINT to command")
return trace.Wrap(cmd.Process.Signal(syscall.SIGINT))
return trace.Wrap(signalOnce())
}
return nil
})
}

// Set host user.
hostUser, err := user.Lookup(mcpSpec.RunAsHostUser)
Expand Down
45 changes: 36 additions & 9 deletions lib/srv/mcp/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package mcp

import (
"context"
"math/rand/v2"
"os"
"path"
"testing"
Expand Down Expand Up @@ -213,6 +214,30 @@ func TestHandleSession_execMCPServer(t *testing.T) {
afterHandlerStart: connectAfterHandlerStart,
afterHandlerStop: containerShouldBeRemoved,
},
{
// Randomly cancel the context to simulate a case where client
// disconnects while cmd is being set up, which may cause a race
// condition that leaves docker container behind:
// https://github.com/gravitational/teleport/issues/59768
//
// To restore the bug, use `sync.OnceValue` when creating the func for `cmd.Cancel`:
// cmd.Cancel = sync.OnceValue(func() error {
//
// Run this test with -count=100 for better coverage.
name: "random cancel handler context",
cmd: "docker",
dockerRunArgs: []string{"mcp/everything"},
checkHandlerError: func(require.TestingT, error, ...interface{}) {
// Depends on the timing, this can return error or nil. So just ignore.
},
cancelHandlerCtx: true,
waitForHandlerExit: time.Second * 15,
afterHandlerStart: func(t *testing.T, testCtx *testContext, containerName string) {
time.Sleep(time.Duration(rand.Uint32N(10000)) * time.Microsecond)
},
// Make sure the container is removed no matter the timing.
afterHandlerStop: containerShouldBeRemoved,
},
{
// Make sure handler is not blocked when command fails to start.
name: "fail to start",
Expand Down Expand Up @@ -246,12 +271,16 @@ func TestHandleSession_execMCPServer(t *testing.T) {
`trap "" INT; while :; do sleep 1; done`,
},
checkHandlerError: require.Error,
afterHandlerStart: func(t *testing.T, testCtx *testContext, _ string) {
// Trigger shutdown.
afterHandlerStart: func(t *testing.T, testCtx *testContext, containerName string) {
ctx := t.Context()
t.Log("waiting for docker container to spawn before killing client connection")
require.EventuallyWithT(t, func(t *assert.CollectT) {
require.NotEmpty(t, findDockerContainerID(ctx, dockerClient, containerName))
}, time.Second*5, time.Millisecond*100)
testCtx.clientSourceConn.Close()
t.Log("waiting 10 seconds for SIGKILL")
},
waitForHandlerExit: time.Second * 15,
waitForHandlerExit: time.Second * 20,
},
}

Expand Down Expand Up @@ -281,12 +310,9 @@ func TestHandleSession_execMCPServer(t *testing.T) {
testCtx := setupTestContext(t, withAdminRole(t), withApp(app))
handlerCtx, handlerCtxCancel := context.WithCancel(t.Context())
defer handlerCtxCancel()
handlerDoneCh := make(chan struct{}, 1)
defer close(handlerDoneCh)
handlerErrChan := make(chan error, 1)
go func() {
handlerErr := s.HandleSession(handlerCtx, testCtx.SessionCtx)
handlerDoneCh <- struct{}{}
tt.checkHandlerError(t, handlerErr)
handlerErrChan <- s.HandleSession(handlerCtx, testCtx.SessionCtx)
}()

if tt.afterHandlerStart != nil {
Expand All @@ -299,7 +325,8 @@ func TestHandleSession_execMCPServer(t *testing.T) {
select {
case <-time.After(tt.waitForHandlerExit):
require.Fail(t, "timed out waiting for handler")
case <-handlerDoneCh:
case handlerErr := <-handlerErrChan:
tt.checkHandlerError(t, handlerErr)
}

if tt.afterHandlerStop != nil {
Expand Down
Loading