diff --git a/lib/client/api.go b/lib/client/api.go
index 07e04f8b41ef6..9e8bd3b0cb794 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -1900,6 +1900,14 @@ type SSHOptions struct {
// machine. If provided, it will be used instead of establishing a connection
// to the target host and executing the command remotely.
LocalCommandExecutor func(string, []string) error
+ // OnChildAuthenticate is a function to run in the child process during
+ // --fork-after authentications. It runs after authentication completes
+ // but before the session begins.
+ OnChildAuthenticate func() error
+}
+
+func (opts SSHOptions) forkAfterAuthentication() bool {
+ return opts.OnChildAuthenticate != nil
}
// WithHostAddress returns a SSHOptions which overrides the
@@ -1918,6 +1926,15 @@ func WithLocalCommandExecutor(executor func(string, []string) error) func(*SSHOp
}
}
+// WithForkAfterAuthentication indicates that tsh is currently reexec-ing
+// for --fork-after-authentication. The given function is called after
+// authentication is complete but before the session starts.
+func WithForkAfterAuthentication(onAuthenticate func() error) func(*SSHOptions) {
+ return func(opt *SSHOptions) {
+ opt.OnChildAuthenticate = onAuthenticate
+ }
+}
+
// SSH connects to a node and, if 'command' is specified, executes the command on it,
// otherwise runs interactive shell
//
@@ -1960,9 +1977,14 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
}
if len(nodeAddrs) > 1 {
+ if options.forkAfterAuthentication() {
+ return &NonRetryableError{
+ Err: trace.BadParameter("fork after authentication not supported for commands on multiple nodes"),
+ }
+ }
return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command)
}
- return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor)
+ return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options)
}
// ConnectToNode attempts to establish a connection to the node resolved to by the provided
@@ -2165,7 +2187,7 @@ func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *Cluster
return nodeClient, trace.Wrap(err)
}
-func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, commandExecutor func(string, []string) error) error {
+func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, options SSHOptions) error {
cluster := clt.ClusterName()
ctx, span := tc.Tracer.Start(
ctx,
@@ -2189,6 +2211,12 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
return trace.Wrap(err)
}
defer nodeClient.Close()
+
+ if options.OnChildAuthenticate != nil {
+ if err := options.OnChildAuthenticate(); err != nil {
+ return trace.Wrap(err)
+ }
+ }
// If forwarding ports were specified, start port forwarding.
if err := tc.startPortForwarding(ctx, nodeClient); err != nil {
return trace.Wrap(err)
@@ -2220,11 +2248,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
// After port forwarding, run a local command that uses the connection, and
// then disconnect.
- if commandExecutor != nil {
+ if options.LocalCommandExecutor != nil {
if len(tc.Config.LocalForwardPorts) == 0 {
fmt.Println("Executing command locally without connecting to any servers. This makes no sense.")
}
- return commandExecutor(tc.Config.HostLogin, command)
+ return options.LocalCommandExecutor(tc.Config.HostLogin, command)
}
if len(command) > 0 {
@@ -2259,7 +2287,7 @@ func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context,
// Issue "shell" request to the first matching node.
fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes match the label selector, picking first: %q\n", nodeAddrs[0])
- return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, nil)
+ return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, SSHOptions{})
}
func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *NodeClient) error {
diff --git a/lib/client/reexec/reexec.go b/lib/client/reexec/reexec.go
new file mode 100644
index 0000000000000..bd382e251d9f7
--- /dev/null
+++ b/lib/client/reexec/reexec.go
@@ -0,0 +1,148 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package reexec
+
+import (
+ "context"
+ "errors"
+ "io"
+ "os"
+ "os/exec"
+ "strings"
+
+ "github.com/gravitational/trace"
+)
+
+// NotifyFileSignal signals on the returned channel when the provided file
+// receives a signal (a one-byte read).
+func NotifyFileSignal(f *os.File) <-chan error {
+ errorCh := make(chan error, 1)
+ go func() {
+ n, err := f.Read(make([]byte, 1))
+ if n > 0 {
+ errorCh <- nil
+ } else if err == nil {
+ // this should be impossible according to the io.Reader contract
+ errorCh <- io.ErrUnexpectedEOF
+ } else {
+ errorCh <- err
+ }
+ }()
+ return errorCh
+}
+
+// SignalAndClose writes a byte to the provided file (to signal a caller of
+// NotifyFileSignal) and closes it.
+func SignalAndClose(f *os.File) error {
+ _, err := f.Write([]byte{0x00})
+ return trace.NewAggregate(err, f.Close())
+}
+
+// ForkAuthenticateParams are the parameters to RunForkAuthenticate.
+type ForkAuthenticateParams struct {
+ // GetArgs gets the arguments to re-exec with, excluding the executable
+ // (equivalent to os.Args[1:]).
+ GetArgs func(signalFd, killFd uint64) []string
+ // executable is the executable to run while re-execing. Overridden in tests.
+ executable string
+ // Stdin is the child process' stdin.
+ Stdin io.Reader
+ // Stdout is the child process' stdout.
+ Stdout io.Writer
+ // Stderr is the child process' stderr.
+ Stderr io.Writer
+}
+
+// RunForkAuthenticate re-execs the current executable and waits for any of
+// the following:
+// - The child process exits (usually in error).
+// - The child process signals the parent that it is ready to be disowned.
+// - The context is canceled.
+func RunForkAuthenticate(ctx context.Context, params ForkAuthenticateParams) error {
+ if params.executable == "" {
+ executable, err := getExecutable()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ params.executable = executable
+ }
+ cmd := exec.Command(params.executable)
+ // Set up signal pipes.
+ disownR, disownW, err := os.Pipe()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ killR, killW, err := os.Pipe()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer func() {
+ // If the child is still listening, kill it. If the child successfully
+ // disowned, this will do nothing.
+ SignalAndClose(killW)
+ killR.Close()
+ disownW.Close()
+ disownR.Close()
+ }()
+
+ signalFd, killFd := configureReexecForOS(cmd, disownW, killR)
+ cmd.Args = append(cmd.Args, params.GetArgs(signalFd, killFd)...)
+ cmd.Args[0] = os.Args[0]
+ cmd.Stdin = params.Stdin
+ cmd.Stdout = params.Stdout
+ cmd.Stderr = params.Stderr
+
+ if err := cmd.Start(); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Clean up parent end of pipes.
+ if err := disownW.Close(); err != nil {
+ return trace.NewAggregate(err, killAndWaitProcess(cmd))
+ }
+ if err := killR.Close(); err != nil {
+ return trace.NewAggregate(err, killAndWaitProcess(cmd))
+ }
+
+ select {
+ case err := <-NotifyFileSignal(disownR):
+ if err == nil {
+ return trace.Wrap(cmd.Process.Release())
+ } else if errors.Is(err, io.EOF) {
+ // EOF means the child process exited, no need to report it on top of kill/wait.
+ return trace.Wrap(killAndWaitProcess(cmd))
+ }
+ return trace.NewAggregate(err, killAndWaitProcess(cmd))
+ case <-ctx.Done():
+ return trace.NewAggregate(ctx.Err(), killAndWaitProcess(cmd))
+ }
+}
+
+func killAndWaitProcess(cmd *exec.Cmd) error {
+ if err := cmd.Process.Kill(); err != nil {
+ return trace.Wrap(err)
+ }
+ err := cmd.Wait()
+ var execErr *exec.ExitError
+ if errors.As(err, &execErr) && execErr.ExitCode() != 0 {
+ return trace.Wrap(err)
+ } else if err != nil && strings.Contains(err.Error(), "signal: killed") {
+ // If the process was successfully killed, there is no issue.
+ return nil
+ }
+ return trace.Wrap(err)
+}
diff --git a/lib/client/reexec/reexec_darwin.go b/lib/client/reexec/reexec_darwin.go
new file mode 100644
index 0000000000000..f4ee34d375374
--- /dev/null
+++ b/lib/client/reexec/reexec_darwin.go
@@ -0,0 +1,43 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build darwin
+
+package reexec
+
+import (
+ "os"
+ "os/exec"
+ "syscall"
+
+ "github.com/gravitational/trace"
+)
+
+// getExecutable gets the path to the executable that should be used for re-exec.
+func getExecutable() (string, error) {
+ executable, err := os.Executable()
+ return executable, trace.Wrap(err)
+}
+
+// configureReexecForOS configures the command with files to inherit and
+// os-specific tweaks.
+func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setsid: true,
+ }
+ cmd.ExtraFiles = []*os.File{signal, kill}
+ return 3, 4
+}
diff --git a/lib/client/reexec/reexec_linux.go b/lib/client/reexec/reexec_linux.go
new file mode 100644
index 0000000000000..e5d843f4ce820
--- /dev/null
+++ b/lib/client/reexec/reexec_linux.go
@@ -0,0 +1,40 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build linux
+
+package reexec
+
+import (
+ "os"
+ "os/exec"
+ "syscall"
+)
+
+// getExecutable gets the path to the executable that should be used for re-exec.
+func getExecutable() (string, error) {
+ return "/proc/self/exe", nil
+}
+
+// configureReexecForOS configures the command with files to inherit and
+// os-specific tweaks.
+func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setsid: true,
+ }
+ cmd.ExtraFiles = []*os.File{signal, kill}
+ return 3, 4
+}
diff --git a/lib/client/reexec/reexec_test.go b/lib/client/reexec/reexec_test.go
new file mode 100644
index 0000000000000..bfdb10a0041bb
--- /dev/null
+++ b/lib/client/reexec/reexec_test.go
@@ -0,0 +1,185 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package reexec
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "os/exec"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/lib/utils/testutils"
+)
+
+type syncBuffer struct {
+ buf *bytes.Buffer
+ mu sync.Mutex
+}
+
+func newSyncBuffer() *syncBuffer {
+ return &syncBuffer{
+ buf: &bytes.Buffer{},
+ }
+}
+
+func (rw *syncBuffer) Read(b []byte) (int, error) {
+ rw.mu.Lock()
+ defer rw.mu.Unlock()
+ return rw.buf.Read(b)
+}
+
+func (rw *syncBuffer) Write(b []byte) (int, error) {
+ rw.mu.Lock()
+ defer rw.mu.Unlock()
+ return rw.buf.Write(b)
+}
+
+func (rw *syncBuffer) String() string {
+ rw.mu.Lock()
+ defer rw.mu.Unlock()
+ return rw.buf.String()
+}
+
+func TestRunForkAuthenticate(t *testing.T) {
+ t.Parallel()
+
+ t.Run("child disowns successfully", func(t *testing.T) {
+ t.Parallel()
+ const script = `
+ read
+ # Close signal fd.
+ echo x >&%d
+ exec %d>&-
+ # stdout/err should still work.
+ echo "stdout: $REPLY"
+ echo "stderr: $REPLY" >&2
+ # Wait to ensure the fd closure is caught before the process ends.
+ sleep 1
+ `
+ getArgs := func(signalFd, killFd uint64) []string {
+ return []string{"-c", fmt.Sprintf(script, signalFd, signalFd)}
+ }
+ stdout := newSyncBuffer()
+ stderr := newSyncBuffer()
+ params := ForkAuthenticateParams{
+ GetArgs: getArgs,
+ executable: "bash",
+ Stdin: bytes.NewBufferString("hello\n"),
+ Stdout: stdout,
+ Stderr: stderr,
+ }
+
+ err := RunForkAuthenticate(t.Context(), params)
+ assert.NoError(t, err)
+ assert.EventuallyWithT(t, func(t *assert.CollectT) {
+ assert.Equal(t, "stdout: hello\n", stdout.String())
+ assert.Equal(t, "stderr: hello\n", stderr.String())
+ }, 10*time.Second, 100*time.Millisecond)
+ })
+
+ t.Run("child exits with error", func(t *testing.T) {
+ t.Parallel()
+ const script = `
+ # Make sure stdin/out/err work.
+ read
+ echo "stdout: $REPLY"
+ echo "stderr: $REPLY" >&2
+ # Exit with error.
+ exit 42
+ `
+ getArgs := func(signalFd, killFd uint64) []string {
+ return []string{"-c", script}
+ }
+ stdout := newSyncBuffer()
+ stderr := newSyncBuffer()
+ params := ForkAuthenticateParams{
+ GetArgs: getArgs,
+ executable: "bash",
+ Stdin: bytes.NewBufferString("hello\n"),
+ Stdout: stdout,
+ Stderr: stderr,
+ }
+
+ err := RunForkAuthenticate(t.Context(), params)
+ var execErr *exec.ExitError
+ if assert.ErrorAs(t, err, &execErr) {
+ assert.Equal(t, 42, execErr.ExitCode())
+ }
+ assert.Equal(t, "stdout: hello\n", stdout.String())
+ assert.Equal(t, "stderr: hello\n", stderr.String())
+ })
+
+ t.Run("context canceled", func(t *testing.T) {
+ t.Parallel()
+ getArgs := func(_, _ uint64) []string {
+ return []string{"-c", `
+ # Make sure stdin/out/err work.
+ read
+ echo "stdout: $REPLY"
+ echo "stderr: $REPLY" >&2
+ # wait for cancellation
+ sleep 2
+ # should not be executed
+ echo "extra output"
+ `}
+ }
+ stdout := newSyncBuffer()
+ stderr := newSyncBuffer()
+ params := ForkAuthenticateParams{
+ GetArgs: getArgs,
+ executable: "bash",
+ Stdin: bytes.NewBufferString("hello\n"),
+ Stdout: stdout,
+ Stderr: stderr,
+ }
+ ctx, cancel := context.WithCancel(t.Context())
+ t.Cleanup(cancel)
+
+ errorCh := make(chan error, 1)
+ testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{
+ Name: "RunForkAuthenticate",
+ Task: func(ctx context.Context) error {
+ errorCh <- RunForkAuthenticate(ctx, params)
+ return nil
+ },
+ })
+
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ assert.Equal(t, "stdout: hello\n", stdout.String())
+ assert.Equal(t, "stderr: hello\n", stderr.String())
+ }, 10*time.Second, 100*time.Millisecond)
+
+ cancel()
+ select {
+ case err := <-errorCh:
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "timed out waiting for child to finish")
+ }
+
+ require.Never(t, func() bool {
+ return strings.Contains(stdout.String(), "extra output")
+ }, 3*time.Second, time.Second)
+ })
+}
diff --git a/lib/client/reexec/reexec_windows.go b/lib/client/reexec/reexec_windows.go
new file mode 100644
index 0000000000000..be690592c4dc0
--- /dev/null
+++ b/lib/client/reexec/reexec_windows.go
@@ -0,0 +1,48 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build windows
+
+package reexec
+
+import (
+ "os"
+ "os/exec"
+ "runtime"
+ "syscall"
+
+ "github.com/gravitational/trace"
+)
+
+// getExecutable gets the path to the executable that should be used for re-exec.
+func getExecutable() (string, error) {
+ executable, err := os.Executable()
+ return executable, trace.Wrap(err)
+}
+
+// configureReexecForOS configures the command with files to inherit and
+// os-specific tweaks.
+func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
+ // Prevent handle from being closed when signal is garbage collected.
+ runtime.SetFinalizer(signal, nil)
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ AdditionalInheritedHandles: []syscall.Handle{
+ syscall.Handle(signal.Fd()),
+ syscall.Handle(kill.Fd()),
+ },
+ }
+ return uint64(signal.Fd()), uint64(kill.Fd())
+}
diff --git a/tool/tsh/common/dup2_linux.go b/tool/tsh/common/dup2_linux.go
new file mode 100644
index 0000000000000..52ca8b0585e2c
--- /dev/null
+++ b/tool/tsh/common/dup2_linux.go
@@ -0,0 +1,34 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build linux
+
+package common
+
+import "syscall"
+
+// dup2 implements syscall.Dup2(oldfd, newfd) in a way that works on all
+// current Linux platforms, and likely on any new platforms. New platforms
+// such as ARM64 do not implement syscall.Dup2() instead implementing
+// syscall.Dup3() which is largely a superset, with one special case.
+func dup2(oldfd, newfd int) error {
+ if oldfd == newfd {
+ // dup2 would do nothing in this case, but dup3 returns an error.
+ // Emulate dup2 behavior.
+ return nil
+ }
+ return syscall.Dup3(oldfd, newfd, 0)
+}
diff --git a/tool/tsh/common/dup2_unix.go b/tool/tsh/common/dup2_unix.go
new file mode 100644
index 0000000000000..47aef96884851
--- /dev/null
+++ b/tool/tsh/common/dup2_unix.go
@@ -0,0 +1,28 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build unix && !linux
+
+package common
+
+import "syscall"
+
+// dup2 wraps syscall.Dup2(oldfd, newfd) on non-linux unix platforms. The
+// linux implementation uses syscall.Dup3() as Dup2() is not available
+// on all linux platforms.
+func dup2(oldfd, newfd int) error {
+ return syscall.Dup2(oldfd, newfd)
+}
diff --git a/tool/tsh/common/reexec_unix.go b/tool/tsh/common/reexec_unix.go
new file mode 100644
index 0000000000000..4a707ab8a62ca
--- /dev/null
+++ b/tool/tsh/common/reexec_unix.go
@@ -0,0 +1,69 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build unix
+
+package common
+
+import (
+ "os"
+ "syscall"
+
+ "github.com/gravitational/trace"
+)
+
+func isValidForkSignalFd(fd uint64) bool {
+ // Don't allow stdin, stdout, or stderr.
+ return fd > 2
+}
+
+// newSignalFile creates a signaling file for --fork-after-authentication from
+// a file descriptor.
+func newSignalFile(fd uint64) *os.File {
+ syscall.CloseOnExec(int(fd))
+ return os.NewFile(uintptr(fd), "disown")
+}
+
+// replaceStdin returns a file for /dev/null that should be used from now
+// on instead of stdin.
+func replaceStdin() (*os.File, error) {
+ devNull, err := os.Open(os.DevNull)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ rc, err := devNull.SyscallConn()
+ if err != nil {
+ _ = devNull.Close()
+ return nil, trace.Wrap(err)
+ }
+ var dupErr error
+ if ctrlErr := rc.Control(func(fd uintptr) {
+ dupErr = dup2(int(fd), syscall.Stdin)
+ // dup2() is sufficient here as the three stdio file
+ // descriptors must not be O_CLOEXEC. Darwin does not have
+ // dup3(), so would need to resort to syscall.ForkLock
+ // shenanigans if we did need to set O_CLOEXEC.
+ }); ctrlErr != nil {
+ _ = devNull.Close()
+ return nil, trace.Wrap(ctrlErr)
+ }
+ if dupErr != nil {
+ // this is the error from dup2
+ _ = devNull.Close()
+ return nil, trace.Wrap(err)
+ }
+ return devNull, err
+}
diff --git a/tool/tsh/common/reexec_windows.go b/tool/tsh/common/reexec_windows.go
new file mode 100644
index 0000000000000..bbfc96a077089
--- /dev/null
+++ b/tool/tsh/common/reexec_windows.go
@@ -0,0 +1,45 @@
+// Teleport
+// Copyright (C) 2025 Gravitational, Inc.
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//go:build windows
+
+package common
+
+import (
+ "os"
+ "syscall"
+
+ "github.com/gravitational/trace"
+)
+
+func isValidForkSignalFd(fd uint64) bool {
+ // Don't allow NULL.
+ return fd != 0
+}
+
+// newSignalFile creates a signaling file for --fork-after-authentication from
+// a file descriptor.
+func newSignalFile(fd uint64) *os.File {
+ syscall.CloseOnExec(syscall.Handle(fd))
+ return os.NewFile(uintptr(fd), "disown")
+}
+
+// replaceStdin returns a file for /dev/null that should be used from now
+// on instead of stdin.
+func replaceStdin() (*os.File, error) {
+ devNull, err := os.Open(os.DevNull)
+ return devNull, trace.Wrap(err)
+}
diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go
index 550811255519b..56d680920562c 100644
--- a/tool/tsh/common/tsh.go
+++ b/tool/tsh/common/tsh.go
@@ -80,6 +80,7 @@ import (
"github.com/gravitational/teleport/lib/client"
dbprofile "github.com/gravitational/teleport/lib/client/db"
"github.com/gravitational/teleport/lib/client/identityfile"
+ "github.com/gravitational/teleport/lib/client/reexec"
"github.com/gravitational/teleport/lib/defaults"
dtauthn "github.com/gravitational/teleport/lib/devicetrust/authn"
dtenroll "github.com/gravitational/teleport/lib/devicetrust/enroll"
@@ -618,6 +619,20 @@ type CLIConf struct {
// atomic here is overkill as the CLIConf is generally consumed sequentially. However, occasionally
// we need concurrency safety, such as for [forEachProfileParallel].
clientStoreSet int32
+
+ // ForkAfterAuthentication indicates that tsh should go into the background
+ // after authentication.
+ ForkAfterAuthentication bool
+ // forkSignalFd is the file descriptor for the child process to signal the
+ // parent when re-execing.
+ forkSignalFd uint64
+ // forkKillFd is the file descriptor for the child process to check the
+ // parent's state when re-execing.
+ forkKillFd uint64
+}
+
+func (c *CLIConf) isForkAuthChild() bool {
+ return isValidForkSignalFd(c.forkSignalFd) && isValidForkSignalFd(c.forkKillFd)
}
// Stdout returns the stdout writer.
@@ -812,6 +827,10 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
app.Flag("cert-format", "SSH certificate format").StringVar(&cf.CertificateFormat)
app.Flag("trace", "Capture and export distributed traces").Hidden().BoolVar(&cf.SampleTraces)
app.Flag("trace-exporter", "An OTLP exporter URL to send spans to. Note - only tsh spans will be included.").Hidden().StringVar(&cf.TraceExporter)
+ // This flag only applies to tsh ssh; it's defined here to make configuring
+ // the re-exec command easier.
+ app.Flag("fork-signal-fd", "File descriptor to signal parent on when forked. Overrides --fork-after-authentication. For internal use only.").Hidden().Uint64Var(&cf.forkSignalFd)
+ app.Flag("fork-kill-fd", "File descriptor to check parent health on when forked. For internal use only.").Hidden().Uint64Var(&cf.forkKillFd)
if !moduleCfg.IsBoringBinary() {
// The user is *never* allowed to do this in FIPS mode.
@@ -891,6 +910,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
ssh.Flag("log-dir", "Directory to log separated command output, when executing on multiple nodes. If set, output from each node will also be labeled in the terminal.").StringVar(&cf.SSHLogDir)
ssh.Flag("no-resume", "Disable SSH connection resumption").Envar(noResumeEnvVar).BoolVar(&cf.DisableSSHResumption)
ssh.Flag("relogin", "Permit performing an authentication attempt on a failed command").Default("true").BoolVar(&cf.Relogin)
+ ssh.Flag("fork-after-authentication", "Run in background after authentication is complete.").Short('f').BoolVar(&cf.ForkAfterAuthentication)
// The following flags are OpenSSH compatibility flags. They are used for
// users that alias "ssh" to "tsh ssh." The following OpenSSH flags are
// implemented. From "man 1 ssh":
@@ -1379,6 +1399,33 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
}
}
+ // Handle fork after authentication.
+ if cf.ForkAfterAuthentication && !cf.isForkAuthChild() {
+ if len(cf.RemoteCommand) == 0 {
+ return trace.BadParameter("fork after authentication not allowed for interactive sessions")
+ }
+ forkParams := reexec.ForkAuthenticateParams{
+ GetArgs: func(signalFd, killFd uint64) []string {
+ return append([]string{
+ // fd flags go immediately after `tsh`.
+ "--fork-signal-fd", strconv.FormatUint(signalFd, 10),
+ "--fork-kill-fd", strconv.FormatUint(killFd, 10),
+ }, args...)
+ },
+ Stdin: cf.Stdin(),
+ Stdout: cf.Stdout(),
+ Stderr: cf.Stderr(),
+ }
+ if err := reexec.RunForkAuthenticate(cf.Context, forkParams); err != nil {
+ var execErr *exec.ExitError
+ if errors.As(trace.Unwrap(err), &execErr) {
+ err = &common.ExitCodeError{Code: execErr.ExitCode()}
+ }
+ return trace.Wrap(err)
+ }
+ return nil
+ }
+
// Remove HTTPS:// in proxy parameter as https is automatically added
cf.Proxy = strings.TrimPrefix(cf.Proxy, "https://")
cf.Proxy = strings.TrimPrefix(cf.Proxy, "HTTPS://")
@@ -4004,6 +4051,29 @@ func onResolve(cf *CLIConf) error {
// onSSH executes 'tsh ssh' command
func onSSH(cf *CLIConf) error {
+ // Handle fork after authentication.
+ var disownSignal *os.File
+ var forkAuthSuccessful atomic.Bool
+ if cf.isForkAuthChild() {
+ ctx, cancel := context.WithCancel(cf.Context)
+ cf.Context = ctx
+ // Prep files.
+ disownSignal = newSignalFile(cf.forkSignalFd)
+ defer disownSignal.Close()
+ killSignal := newSignalFile(cf.forkKillFd)
+ defer killSignal.Close()
+
+ // Watch kill signal to check when parent exits. If the read returns before
+ // the child finishes authentication, the parent has died and the child
+ // needs to die too.
+ go func() {
+ err := <-reexec.NotifyFileSignal(killSignal)
+ if err != nil && !forkAuthSuccessful.Load() {
+ cancel()
+ }
+ }()
+ }
+
// If "tsh ssh -V" is invoked, tsh is in OpenSSH compatibility mode, show
// the version and exit.
if cf.ShowVersion {
@@ -4041,7 +4111,7 @@ func onSSH(cf *CLIConf) error {
cf.RemoteCommand = cf.RemoteCommand[1:]
}
- tc.Stdin = os.Stdin
+ tc.Stdin = cf.Stdin()
err = retryWithAccessRequest(cf, tc, func() error {
sshFunc := func() error {
var opts []func(*client.SSHOptions)
@@ -4049,6 +4119,18 @@ func onSSH(cf *CLIConf) error {
opts = append(opts, client.WithLocalCommandExecutor(runLocalCommand))
}
+ if disownSignal != nil {
+ opts = append(opts, client.WithForkAfterAuthentication(func() error {
+ newStdin, err := replaceStdin()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ tc.Stdin = newStdin
+ forkAuthSuccessful.Store(true)
+ return trace.Wrap(reexec.SignalAndClose(disownSignal))
+ }))
+ }
+
return tc.SSH(cf.Context, cf.RemoteCommand, opts...)
}
if !cf.Relogin {
diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go
index 81b56f63fe03e..aa39623bd7e2e 100644
--- a/tool/tsh/common/tsh_test.go
+++ b/tool/tsh/common/tsh_test.go
@@ -201,10 +201,11 @@ func handleReexec() {
})
}
- // Allows test to refer to tsh binary in tests.
- // Needed for tests that generate OpenSSH config by tsh config command where
- // tsh proxy ssh command is used as ProxyCommand.
- if os.Getenv(tshBinMainTestEnv) != "" {
+ // Re-exec tsh commands. Needed for:
+ // - Tests that generate OpenSSH config by tsh config command where
+ // tsh proxy ssh command is used as ProxyCommand.
+ // - Fork after authentication.
+ if os.Getenv(tshBinMainTestEnv) != "" || (len(os.Args) >= 2 && os.Args[1] == "--fork-signal-fd") {
if os.Getenv(tshBinMainTestOneshotEnv) != "" {
// unset this env var so child processes started by 'tsh ssh'
// will be executed correctly below.
@@ -224,8 +225,7 @@ func handleReexec() {
os.Exit(0)
}
- // If the test is re-executing itself, execute the command that comes over
- // the pipe. Used to test tsh ssh command.
+ // Re-exec teleport commands. Used to test tsh ssh command.
if srv.IsReexec() {
srv.RunAndExit(os.Args[1])
}
@@ -7456,3 +7456,113 @@ func prepareCLIOptionForReadingLoggingOpts() (func(t *testing.T) loggingOpts, Cl
return mustReadLoggingOpts, setLoggingOptsFromCLIConf
}
+
+func TestSSHForkAfterAuthentication(t *testing.T) {
+ u, err := user.Current()
+ require.NoError(t, err)
+
+ // Create resources.
+ accessRole, err := types.NewRole("node-access", types.RoleSpecV6{
+ Allow: types.RoleConditions{
+ NodeLabels: types.Labels{
+ types.Wildcard: []string{types.Wildcard},
+ },
+ Logins: []string{u.Username},
+ },
+ })
+ require.NoError(t, err)
+ connector := mockConnector(t)
+ alice, err := types.NewUser("alice@example.com")
+ require.NoError(t, err)
+ alice.SetRoles([]string{accessRole.GetName()})
+
+ tsrv := testserver.MakeTestServer(t,
+ testserver.WithSSHLabel("foo", "bar"),
+ testserver.WithBootstrap(connector, accessRole, alice),
+ )
+ t.Cleanup(func() { require.NoError(t, tsrv.Close()) })
+ // We don't need a real second node for multi-node exec, tsh ssh should fail before that.
+ fakeNode, err := types.NewNode("fake", types.SubKindTeleportNode, types.ServerSpecV2{}, map[string]string{"foo": "bar"})
+ require.NoError(t, err)
+ _, err = tsrv.GetAuthServer().UpsertNode(t.Context(), fakeNode)
+ require.NoError(t, err)
+
+ // Use env var instead of homedir mock to preserve across the re-exec.
+ tmpHomeDir := filepath.Join(t.TempDir(), ".tsh")
+ t.Setenv(types.HomeEnvVar, tmpHomeDir)
+ proxyAddr, err := tsrv.ProxyWebAddr()
+ require.NoError(t, err)
+
+ err = Run(t.Context(), []string{
+ "login",
+ "--insecure",
+ "--proxy", proxyAddr.Addr,
+ }, setMockSSOLogin(tsrv.GetAuthServer(), alice, connector.GetName()))
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ target string
+ command []string
+ assertRun assert.ErrorAssertionFunc
+ assertCommandEffect func(t *testing.T, testFile string) bool
+ }{
+ {
+ name: "ok",
+ target: tsrv.Config.Hostname,
+ command: []string{"echo", "hello", ">", "test.txt"},
+ assertRun: assert.NoError,
+ assertCommandEffect: func(t *testing.T, testFile string) bool {
+ return assert.EventuallyWithT(t, func(collect *assert.CollectT) {
+ assert.FileExists(collect, testFile)
+ }, 3*time.Second, 100*time.Millisecond)
+ },
+ },
+ {
+ name: "stdin is closed after disowning",
+ target: tsrv.Config.Hostname,
+ command: []string{"read", "&&", "echo", "should not happen", ">", "test.txt"},
+ assertRun: assert.NoError,
+ assertCommandEffect: func(t *testing.T, testFile string) bool {
+ return assert.Never(t, func() bool {
+ _, err := os.Stat(testFile)
+ return !errors.Is(err, os.ErrNotExist)
+ }, 3*time.Second, time.Second)
+ },
+ },
+ {
+ name: "not allowed on multiple nodes",
+ target: "foo=bar",
+ command: []string{"echo", "hello", ">", "test.txt"},
+ assertRun: assert.Error,
+ },
+ {
+ name: "not allowed for interactive commands",
+ target: tsrv.Config.Hostname,
+ command: []string{},
+ assertRun: assert.Error,
+ },
+ }
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Configure command with a real file path.
+ testFile := filepath.Join(t.TempDir(), "test.txt")
+ cmd := make([]string, 0, len(tc.command))
+ for _, arg := range tc.command {
+ cmd = append(cmd, strings.ReplaceAll(arg, "test.txt", testFile))
+ }
+
+ err := Run(t.Context(), append([]string{
+ "ssh",
+ "--insecure",
+ "-f",
+ u.Username + "@" + tc.target,
+ }, cmd...))
+ tc.assertRun(t, err)
+ if tc.assertCommandEffect != nil {
+ tc.assertCommandEffect(t, testFile)
+ }
+ })
+ }
+}