diff --git a/lib/client/api.go b/lib/client/api.go index 07e04f8b41ef6..9e8bd3b0cb794 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1900,6 +1900,14 @@ type SSHOptions struct { // machine. If provided, it will be used instead of establishing a connection // to the target host and executing the command remotely. LocalCommandExecutor func(string, []string) error + // OnChildAuthenticate is a function to run in the child process during + // --fork-after authentications. It runs after authentication completes + // but before the session begins. + OnChildAuthenticate func() error +} + +func (opts SSHOptions) forkAfterAuthentication() bool { + return opts.OnChildAuthenticate != nil } // WithHostAddress returns a SSHOptions which overrides the @@ -1918,6 +1926,15 @@ func WithLocalCommandExecutor(executor func(string, []string) error) func(*SSHOp } } +// WithForkAfterAuthentication indicates that tsh is currently reexec-ing +// for --fork-after-authentication. The given function is called after +// authentication is complete but before the session starts. +func WithForkAfterAuthentication(onAuthenticate func() error) func(*SSHOptions) { + return func(opt *SSHOptions) { + opt.OnChildAuthenticate = onAuthenticate + } +} + // SSH connects to a node and, if 'command' is specified, executes the command on it, // otherwise runs interactive shell // @@ -1960,9 +1977,14 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun } if len(nodeAddrs) > 1 { + if options.forkAfterAuthentication() { + return &NonRetryableError{ + Err: trace.BadParameter("fork after authentication not supported for commands on multiple nodes"), + } + } return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command) } - return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor) + return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options) } // ConnectToNode attempts to establish a connection to the node resolved to by the provided @@ -2165,7 +2187,7 @@ func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *Cluster return nodeClient, trace.Wrap(err) } -func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, commandExecutor func(string, []string) error) error { +func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, options SSHOptions) error { cluster := clt.ClusterName() ctx, span := tc.Tracer.Start( ctx, @@ -2189,6 +2211,12 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt return trace.Wrap(err) } defer nodeClient.Close() + + if options.OnChildAuthenticate != nil { + if err := options.OnChildAuthenticate(); err != nil { + return trace.Wrap(err) + } + } // If forwarding ports were specified, start port forwarding. if err := tc.startPortForwarding(ctx, nodeClient); err != nil { return trace.Wrap(err) @@ -2220,11 +2248,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt // After port forwarding, run a local command that uses the connection, and // then disconnect. - if commandExecutor != nil { + if options.LocalCommandExecutor != nil { if len(tc.Config.LocalForwardPorts) == 0 { fmt.Println("Executing command locally without connecting to any servers. This makes no sense.") } - return commandExecutor(tc.Config.HostLogin, command) + return options.LocalCommandExecutor(tc.Config.HostLogin, command) } if len(command) > 0 { @@ -2259,7 +2287,7 @@ func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, // Issue "shell" request to the first matching node. fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes match the label selector, picking first: %q\n", nodeAddrs[0]) - return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, nil) + return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, SSHOptions{}) } func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *NodeClient) error { diff --git a/lib/client/reexec/reexec.go b/lib/client/reexec/reexec.go new file mode 100644 index 0000000000000..bd382e251d9f7 --- /dev/null +++ b/lib/client/reexec/reexec.go @@ -0,0 +1,148 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package reexec + +import ( + "context" + "errors" + "io" + "os" + "os/exec" + "strings" + + "github.com/gravitational/trace" +) + +// NotifyFileSignal signals on the returned channel when the provided file +// receives a signal (a one-byte read). +func NotifyFileSignal(f *os.File) <-chan error { + errorCh := make(chan error, 1) + go func() { + n, err := f.Read(make([]byte, 1)) + if n > 0 { + errorCh <- nil + } else if err == nil { + // this should be impossible according to the io.Reader contract + errorCh <- io.ErrUnexpectedEOF + } else { + errorCh <- err + } + }() + return errorCh +} + +// SignalAndClose writes a byte to the provided file (to signal a caller of +// NotifyFileSignal) and closes it. +func SignalAndClose(f *os.File) error { + _, err := f.Write([]byte{0x00}) + return trace.NewAggregate(err, f.Close()) +} + +// ForkAuthenticateParams are the parameters to RunForkAuthenticate. +type ForkAuthenticateParams struct { + // GetArgs gets the arguments to re-exec with, excluding the executable + // (equivalent to os.Args[1:]). + GetArgs func(signalFd, killFd uint64) []string + // executable is the executable to run while re-execing. Overridden in tests. + executable string + // Stdin is the child process' stdin. + Stdin io.Reader + // Stdout is the child process' stdout. + Stdout io.Writer + // Stderr is the child process' stderr. + Stderr io.Writer +} + +// RunForkAuthenticate re-execs the current executable and waits for any of +// the following: +// - The child process exits (usually in error). +// - The child process signals the parent that it is ready to be disowned. +// - The context is canceled. +func RunForkAuthenticate(ctx context.Context, params ForkAuthenticateParams) error { + if params.executable == "" { + executable, err := getExecutable() + if err != nil { + return trace.Wrap(err) + } + params.executable = executable + } + cmd := exec.Command(params.executable) + // Set up signal pipes. + disownR, disownW, err := os.Pipe() + if err != nil { + return trace.Wrap(err) + } + killR, killW, err := os.Pipe() + if err != nil { + return trace.Wrap(err) + } + defer func() { + // If the child is still listening, kill it. If the child successfully + // disowned, this will do nothing. + SignalAndClose(killW) + killR.Close() + disownW.Close() + disownR.Close() + }() + + signalFd, killFd := configureReexecForOS(cmd, disownW, killR) + cmd.Args = append(cmd.Args, params.GetArgs(signalFd, killFd)...) + cmd.Args[0] = os.Args[0] + cmd.Stdin = params.Stdin + cmd.Stdout = params.Stdout + cmd.Stderr = params.Stderr + + if err := cmd.Start(); err != nil { + return trace.Wrap(err) + } + + // Clean up parent end of pipes. + if err := disownW.Close(); err != nil { + return trace.NewAggregate(err, killAndWaitProcess(cmd)) + } + if err := killR.Close(); err != nil { + return trace.NewAggregate(err, killAndWaitProcess(cmd)) + } + + select { + case err := <-NotifyFileSignal(disownR): + if err == nil { + return trace.Wrap(cmd.Process.Release()) + } else if errors.Is(err, io.EOF) { + // EOF means the child process exited, no need to report it on top of kill/wait. + return trace.Wrap(killAndWaitProcess(cmd)) + } + return trace.NewAggregate(err, killAndWaitProcess(cmd)) + case <-ctx.Done(): + return trace.NewAggregate(ctx.Err(), killAndWaitProcess(cmd)) + } +} + +func killAndWaitProcess(cmd *exec.Cmd) error { + if err := cmd.Process.Kill(); err != nil { + return trace.Wrap(err) + } + err := cmd.Wait() + var execErr *exec.ExitError + if errors.As(err, &execErr) && execErr.ExitCode() != 0 { + return trace.Wrap(err) + } else if err != nil && strings.Contains(err.Error(), "signal: killed") { + // If the process was successfully killed, there is no issue. + return nil + } + return trace.Wrap(err) +} diff --git a/lib/client/reexec/reexec_darwin.go b/lib/client/reexec/reexec_darwin.go new file mode 100644 index 0000000000000..f4ee34d375374 --- /dev/null +++ b/lib/client/reexec/reexec_darwin.go @@ -0,0 +1,43 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build darwin + +package reexec + +import ( + "os" + "os/exec" + "syscall" + + "github.com/gravitational/trace" +) + +// getExecutable gets the path to the executable that should be used for re-exec. +func getExecutable() (string, error) { + executable, err := os.Executable() + return executable, trace.Wrap(err) +} + +// configureReexecForOS configures the command with files to inherit and +// os-specific tweaks. +func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + } + cmd.ExtraFiles = []*os.File{signal, kill} + return 3, 4 +} diff --git a/lib/client/reexec/reexec_linux.go b/lib/client/reexec/reexec_linux.go new file mode 100644 index 0000000000000..e5d843f4ce820 --- /dev/null +++ b/lib/client/reexec/reexec_linux.go @@ -0,0 +1,40 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build linux + +package reexec + +import ( + "os" + "os/exec" + "syscall" +) + +// getExecutable gets the path to the executable that should be used for re-exec. +func getExecutable() (string, error) { + return "/proc/self/exe", nil +} + +// configureReexecForOS configures the command with files to inherit and +// os-specific tweaks. +func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + } + cmd.ExtraFiles = []*os.File{signal, kill} + return 3, 4 +} diff --git a/lib/client/reexec/reexec_test.go b/lib/client/reexec/reexec_test.go new file mode 100644 index 0000000000000..bfdb10a0041bb --- /dev/null +++ b/lib/client/reexec/reexec_test.go @@ -0,0 +1,185 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package reexec + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/utils/testutils" +) + +type syncBuffer struct { + buf *bytes.Buffer + mu sync.Mutex +} + +func newSyncBuffer() *syncBuffer { + return &syncBuffer{ + buf: &bytes.Buffer{}, + } +} + +func (rw *syncBuffer) Read(b []byte) (int, error) { + rw.mu.Lock() + defer rw.mu.Unlock() + return rw.buf.Read(b) +} + +func (rw *syncBuffer) Write(b []byte) (int, error) { + rw.mu.Lock() + defer rw.mu.Unlock() + return rw.buf.Write(b) +} + +func (rw *syncBuffer) String() string { + rw.mu.Lock() + defer rw.mu.Unlock() + return rw.buf.String() +} + +func TestRunForkAuthenticate(t *testing.T) { + t.Parallel() + + t.Run("child disowns successfully", func(t *testing.T) { + t.Parallel() + const script = ` + read + # Close signal fd. + echo x >&%d + exec %d>&- + # stdout/err should still work. + echo "stdout: $REPLY" + echo "stderr: $REPLY" >&2 + # Wait to ensure the fd closure is caught before the process ends. + sleep 1 + ` + getArgs := func(signalFd, killFd uint64) []string { + return []string{"-c", fmt.Sprintf(script, signalFd, signalFd)} + } + stdout := newSyncBuffer() + stderr := newSyncBuffer() + params := ForkAuthenticateParams{ + GetArgs: getArgs, + executable: "bash", + Stdin: bytes.NewBufferString("hello\n"), + Stdout: stdout, + Stderr: stderr, + } + + err := RunForkAuthenticate(t.Context(), params) + assert.NoError(t, err) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + assert.Equal(t, "stdout: hello\n", stdout.String()) + assert.Equal(t, "stderr: hello\n", stderr.String()) + }, 10*time.Second, 100*time.Millisecond) + }) + + t.Run("child exits with error", func(t *testing.T) { + t.Parallel() + const script = ` + # Make sure stdin/out/err work. + read + echo "stdout: $REPLY" + echo "stderr: $REPLY" >&2 + # Exit with error. + exit 42 + ` + getArgs := func(signalFd, killFd uint64) []string { + return []string{"-c", script} + } + stdout := newSyncBuffer() + stderr := newSyncBuffer() + params := ForkAuthenticateParams{ + GetArgs: getArgs, + executable: "bash", + Stdin: bytes.NewBufferString("hello\n"), + Stdout: stdout, + Stderr: stderr, + } + + err := RunForkAuthenticate(t.Context(), params) + var execErr *exec.ExitError + if assert.ErrorAs(t, err, &execErr) { + assert.Equal(t, 42, execErr.ExitCode()) + } + assert.Equal(t, "stdout: hello\n", stdout.String()) + assert.Equal(t, "stderr: hello\n", stderr.String()) + }) + + t.Run("context canceled", func(t *testing.T) { + t.Parallel() + getArgs := func(_, _ uint64) []string { + return []string{"-c", ` + # Make sure stdin/out/err work. + read + echo "stdout: $REPLY" + echo "stderr: $REPLY" >&2 + # wait for cancellation + sleep 2 + # should not be executed + echo "extra output" + `} + } + stdout := newSyncBuffer() + stderr := newSyncBuffer() + params := ForkAuthenticateParams{ + GetArgs: getArgs, + executable: "bash", + Stdin: bytes.NewBufferString("hello\n"), + Stdout: stdout, + Stderr: stderr, + } + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + errorCh := make(chan error, 1) + testutils.RunTestBackgroundTask(ctx, t, &testutils.TestBackgroundTask{ + Name: "RunForkAuthenticate", + Task: func(ctx context.Context) error { + errorCh <- RunForkAuthenticate(ctx, params) + return nil + }, + }) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + assert.Equal(t, "stdout: hello\n", stdout.String()) + assert.Equal(t, "stderr: hello\n", stderr.String()) + }, 10*time.Second, 100*time.Millisecond) + + cancel() + select { + case err := <-errorCh: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(5 * time.Second): + require.Fail(t, "timed out waiting for child to finish") + } + + require.Never(t, func() bool { + return strings.Contains(stdout.String(), "extra output") + }, 3*time.Second, time.Second) + }) +} diff --git a/lib/client/reexec/reexec_windows.go b/lib/client/reexec/reexec_windows.go new file mode 100644 index 0000000000000..be690592c4dc0 --- /dev/null +++ b/lib/client/reexec/reexec_windows.go @@ -0,0 +1,48 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build windows + +package reexec + +import ( + "os" + "os/exec" + "runtime" + "syscall" + + "github.com/gravitational/trace" +) + +// getExecutable gets the path to the executable that should be used for re-exec. +func getExecutable() (string, error) { + executable, err := os.Executable() + return executable, trace.Wrap(err) +} + +// configureReexecForOS configures the command with files to inherit and +// os-specific tweaks. +func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) { + // Prevent handle from being closed when signal is garbage collected. + runtime.SetFinalizer(signal, nil) + cmd.SysProcAttr = &syscall.SysProcAttr{ + AdditionalInheritedHandles: []syscall.Handle{ + syscall.Handle(signal.Fd()), + syscall.Handle(kill.Fd()), + }, + } + return uint64(signal.Fd()), uint64(kill.Fd()) +} diff --git a/tool/tsh/common/dup2_linux.go b/tool/tsh/common/dup2_linux.go new file mode 100644 index 0000000000000..52ca8b0585e2c --- /dev/null +++ b/tool/tsh/common/dup2_linux.go @@ -0,0 +1,34 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build linux + +package common + +import "syscall" + +// dup2 implements syscall.Dup2(oldfd, newfd) in a way that works on all +// current Linux platforms, and likely on any new platforms. New platforms +// such as ARM64 do not implement syscall.Dup2() instead implementing +// syscall.Dup3() which is largely a superset, with one special case. +func dup2(oldfd, newfd int) error { + if oldfd == newfd { + // dup2 would do nothing in this case, but dup3 returns an error. + // Emulate dup2 behavior. + return nil + } + return syscall.Dup3(oldfd, newfd, 0) +} diff --git a/tool/tsh/common/dup2_unix.go b/tool/tsh/common/dup2_unix.go new file mode 100644 index 0000000000000..47aef96884851 --- /dev/null +++ b/tool/tsh/common/dup2_unix.go @@ -0,0 +1,28 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build unix && !linux + +package common + +import "syscall" + +// dup2 wraps syscall.Dup2(oldfd, newfd) on non-linux unix platforms. The +// linux implementation uses syscall.Dup3() as Dup2() is not available +// on all linux platforms. +func dup2(oldfd, newfd int) error { + return syscall.Dup2(oldfd, newfd) +} diff --git a/tool/tsh/common/reexec_unix.go b/tool/tsh/common/reexec_unix.go new file mode 100644 index 0000000000000..4a707ab8a62ca --- /dev/null +++ b/tool/tsh/common/reexec_unix.go @@ -0,0 +1,69 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build unix + +package common + +import ( + "os" + "syscall" + + "github.com/gravitational/trace" +) + +func isValidForkSignalFd(fd uint64) bool { + // Don't allow stdin, stdout, or stderr. + return fd > 2 +} + +// newSignalFile creates a signaling file for --fork-after-authentication from +// a file descriptor. +func newSignalFile(fd uint64) *os.File { + syscall.CloseOnExec(int(fd)) + return os.NewFile(uintptr(fd), "disown") +} + +// replaceStdin returns a file for /dev/null that should be used from now +// on instead of stdin. +func replaceStdin() (*os.File, error) { + devNull, err := os.Open(os.DevNull) + if err != nil { + return nil, trace.Wrap(err) + } + rc, err := devNull.SyscallConn() + if err != nil { + _ = devNull.Close() + return nil, trace.Wrap(err) + } + var dupErr error + if ctrlErr := rc.Control(func(fd uintptr) { + dupErr = dup2(int(fd), syscall.Stdin) + // dup2() is sufficient here as the three stdio file + // descriptors must not be O_CLOEXEC. Darwin does not have + // dup3(), so would need to resort to syscall.ForkLock + // shenanigans if we did need to set O_CLOEXEC. + }); ctrlErr != nil { + _ = devNull.Close() + return nil, trace.Wrap(ctrlErr) + } + if dupErr != nil { + // this is the error from dup2 + _ = devNull.Close() + return nil, trace.Wrap(err) + } + return devNull, err +} diff --git a/tool/tsh/common/reexec_windows.go b/tool/tsh/common/reexec_windows.go new file mode 100644 index 0000000000000..bbfc96a077089 --- /dev/null +++ b/tool/tsh/common/reexec_windows.go @@ -0,0 +1,45 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build windows + +package common + +import ( + "os" + "syscall" + + "github.com/gravitational/trace" +) + +func isValidForkSignalFd(fd uint64) bool { + // Don't allow NULL. + return fd != 0 +} + +// newSignalFile creates a signaling file for --fork-after-authentication from +// a file descriptor. +func newSignalFile(fd uint64) *os.File { + syscall.CloseOnExec(syscall.Handle(fd)) + return os.NewFile(uintptr(fd), "disown") +} + +// replaceStdin returns a file for /dev/null that should be used from now +// on instead of stdin. +func replaceStdin() (*os.File, error) { + devNull, err := os.Open(os.DevNull) + return devNull, trace.Wrap(err) +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 550811255519b..56d680920562c 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -80,6 +80,7 @@ import ( "github.com/gravitational/teleport/lib/client" dbprofile "github.com/gravitational/teleport/lib/client/db" "github.com/gravitational/teleport/lib/client/identityfile" + "github.com/gravitational/teleport/lib/client/reexec" "github.com/gravitational/teleport/lib/defaults" dtauthn "github.com/gravitational/teleport/lib/devicetrust/authn" dtenroll "github.com/gravitational/teleport/lib/devicetrust/enroll" @@ -618,6 +619,20 @@ type CLIConf struct { // atomic here is overkill as the CLIConf is generally consumed sequentially. However, occasionally // we need concurrency safety, such as for [forEachProfileParallel]. clientStoreSet int32 + + // ForkAfterAuthentication indicates that tsh should go into the background + // after authentication. + ForkAfterAuthentication bool + // forkSignalFd is the file descriptor for the child process to signal the + // parent when re-execing. + forkSignalFd uint64 + // forkKillFd is the file descriptor for the child process to check the + // parent's state when re-execing. + forkKillFd uint64 +} + +func (c *CLIConf) isForkAuthChild() bool { + return isValidForkSignalFd(c.forkSignalFd) && isValidForkSignalFd(c.forkKillFd) } // Stdout returns the stdout writer. @@ -812,6 +827,10 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { app.Flag("cert-format", "SSH certificate format").StringVar(&cf.CertificateFormat) app.Flag("trace", "Capture and export distributed traces").Hidden().BoolVar(&cf.SampleTraces) app.Flag("trace-exporter", "An OTLP exporter URL to send spans to. Note - only tsh spans will be included.").Hidden().StringVar(&cf.TraceExporter) + // This flag only applies to tsh ssh; it's defined here to make configuring + // the re-exec command easier. + app.Flag("fork-signal-fd", "File descriptor to signal parent on when forked. Overrides --fork-after-authentication. For internal use only.").Hidden().Uint64Var(&cf.forkSignalFd) + app.Flag("fork-kill-fd", "File descriptor to check parent health on when forked. For internal use only.").Hidden().Uint64Var(&cf.forkKillFd) if !moduleCfg.IsBoringBinary() { // The user is *never* allowed to do this in FIPS mode. @@ -891,6 +910,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { ssh.Flag("log-dir", "Directory to log separated command output, when executing on multiple nodes. If set, output from each node will also be labeled in the terminal.").StringVar(&cf.SSHLogDir) ssh.Flag("no-resume", "Disable SSH connection resumption").Envar(noResumeEnvVar).BoolVar(&cf.DisableSSHResumption) ssh.Flag("relogin", "Permit performing an authentication attempt on a failed command").Default("true").BoolVar(&cf.Relogin) + ssh.Flag("fork-after-authentication", "Run in background after authentication is complete.").Short('f').BoolVar(&cf.ForkAfterAuthentication) // The following flags are OpenSSH compatibility flags. They are used for // users that alias "ssh" to "tsh ssh." The following OpenSSH flags are // implemented. From "man 1 ssh": @@ -1379,6 +1399,33 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { } } + // Handle fork after authentication. + if cf.ForkAfterAuthentication && !cf.isForkAuthChild() { + if len(cf.RemoteCommand) == 0 { + return trace.BadParameter("fork after authentication not allowed for interactive sessions") + } + forkParams := reexec.ForkAuthenticateParams{ + GetArgs: func(signalFd, killFd uint64) []string { + return append([]string{ + // fd flags go immediately after `tsh`. + "--fork-signal-fd", strconv.FormatUint(signalFd, 10), + "--fork-kill-fd", strconv.FormatUint(killFd, 10), + }, args...) + }, + Stdin: cf.Stdin(), + Stdout: cf.Stdout(), + Stderr: cf.Stderr(), + } + if err := reexec.RunForkAuthenticate(cf.Context, forkParams); err != nil { + var execErr *exec.ExitError + if errors.As(trace.Unwrap(err), &execErr) { + err = &common.ExitCodeError{Code: execErr.ExitCode()} + } + return trace.Wrap(err) + } + return nil + } + // Remove HTTPS:// in proxy parameter as https is automatically added cf.Proxy = strings.TrimPrefix(cf.Proxy, "https://") cf.Proxy = strings.TrimPrefix(cf.Proxy, "HTTPS://") @@ -4004,6 +4051,29 @@ func onResolve(cf *CLIConf) error { // onSSH executes 'tsh ssh' command func onSSH(cf *CLIConf) error { + // Handle fork after authentication. + var disownSignal *os.File + var forkAuthSuccessful atomic.Bool + if cf.isForkAuthChild() { + ctx, cancel := context.WithCancel(cf.Context) + cf.Context = ctx + // Prep files. + disownSignal = newSignalFile(cf.forkSignalFd) + defer disownSignal.Close() + killSignal := newSignalFile(cf.forkKillFd) + defer killSignal.Close() + + // Watch kill signal to check when parent exits. If the read returns before + // the child finishes authentication, the parent has died and the child + // needs to die too. + go func() { + err := <-reexec.NotifyFileSignal(killSignal) + if err != nil && !forkAuthSuccessful.Load() { + cancel() + } + }() + } + // If "tsh ssh -V" is invoked, tsh is in OpenSSH compatibility mode, show // the version and exit. if cf.ShowVersion { @@ -4041,7 +4111,7 @@ func onSSH(cf *CLIConf) error { cf.RemoteCommand = cf.RemoteCommand[1:] } - tc.Stdin = os.Stdin + tc.Stdin = cf.Stdin() err = retryWithAccessRequest(cf, tc, func() error { sshFunc := func() error { var opts []func(*client.SSHOptions) @@ -4049,6 +4119,18 @@ func onSSH(cf *CLIConf) error { opts = append(opts, client.WithLocalCommandExecutor(runLocalCommand)) } + if disownSignal != nil { + opts = append(opts, client.WithForkAfterAuthentication(func() error { + newStdin, err := replaceStdin() + if err != nil { + return trace.Wrap(err) + } + tc.Stdin = newStdin + forkAuthSuccessful.Store(true) + return trace.Wrap(reexec.SignalAndClose(disownSignal)) + })) + } + return tc.SSH(cf.Context, cf.RemoteCommand, opts...) } if !cf.Relogin { diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index 81b56f63fe03e..aa39623bd7e2e 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -201,10 +201,11 @@ func handleReexec() { }) } - // Allows test to refer to tsh binary in tests. - // Needed for tests that generate OpenSSH config by tsh config command where - // tsh proxy ssh command is used as ProxyCommand. - if os.Getenv(tshBinMainTestEnv) != "" { + // Re-exec tsh commands. Needed for: + // - Tests that generate OpenSSH config by tsh config command where + // tsh proxy ssh command is used as ProxyCommand. + // - Fork after authentication. + if os.Getenv(tshBinMainTestEnv) != "" || (len(os.Args) >= 2 && os.Args[1] == "--fork-signal-fd") { if os.Getenv(tshBinMainTestOneshotEnv) != "" { // unset this env var so child processes started by 'tsh ssh' // will be executed correctly below. @@ -224,8 +225,7 @@ func handleReexec() { os.Exit(0) } - // If the test is re-executing itself, execute the command that comes over - // the pipe. Used to test tsh ssh command. + // Re-exec teleport commands. Used to test tsh ssh command. if srv.IsReexec() { srv.RunAndExit(os.Args[1]) } @@ -7456,3 +7456,113 @@ func prepareCLIOptionForReadingLoggingOpts() (func(t *testing.T) loggingOpts, Cl return mustReadLoggingOpts, setLoggingOptsFromCLIConf } + +func TestSSHForkAfterAuthentication(t *testing.T) { + u, err := user.Current() + require.NoError(t, err) + + // Create resources. + accessRole, err := types.NewRole("node-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + Logins: []string{u.Username}, + }, + }) + require.NoError(t, err) + connector := mockConnector(t) + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetRoles([]string{accessRole.GetName()}) + + tsrv := testserver.MakeTestServer(t, + testserver.WithSSHLabel("foo", "bar"), + testserver.WithBootstrap(connector, accessRole, alice), + ) + t.Cleanup(func() { require.NoError(t, tsrv.Close()) }) + // We don't need a real second node for multi-node exec, tsh ssh should fail before that. + fakeNode, err := types.NewNode("fake", types.SubKindTeleportNode, types.ServerSpecV2{}, map[string]string{"foo": "bar"}) + require.NoError(t, err) + _, err = tsrv.GetAuthServer().UpsertNode(t.Context(), fakeNode) + require.NoError(t, err) + + // Use env var instead of homedir mock to preserve across the re-exec. + tmpHomeDir := filepath.Join(t.TempDir(), ".tsh") + t.Setenv(types.HomeEnvVar, tmpHomeDir) + proxyAddr, err := tsrv.ProxyWebAddr() + require.NoError(t, err) + + err = Run(t.Context(), []string{ + "login", + "--insecure", + "--proxy", proxyAddr.Addr, + }, setMockSSOLogin(tsrv.GetAuthServer(), alice, connector.GetName())) + require.NoError(t, err) + + tests := []struct { + name string + target string + command []string + assertRun assert.ErrorAssertionFunc + assertCommandEffect func(t *testing.T, testFile string) bool + }{ + { + name: "ok", + target: tsrv.Config.Hostname, + command: []string{"echo", "hello", ">", "test.txt"}, + assertRun: assert.NoError, + assertCommandEffect: func(t *testing.T, testFile string) bool { + return assert.EventuallyWithT(t, func(collect *assert.CollectT) { + assert.FileExists(collect, testFile) + }, 3*time.Second, 100*time.Millisecond) + }, + }, + { + name: "stdin is closed after disowning", + target: tsrv.Config.Hostname, + command: []string{"read", "&&", "echo", "should not happen", ">", "test.txt"}, + assertRun: assert.NoError, + assertCommandEffect: func(t *testing.T, testFile string) bool { + return assert.Never(t, func() bool { + _, err := os.Stat(testFile) + return !errors.Is(err, os.ErrNotExist) + }, 3*time.Second, time.Second) + }, + }, + { + name: "not allowed on multiple nodes", + target: "foo=bar", + command: []string{"echo", "hello", ">", "test.txt"}, + assertRun: assert.Error, + }, + { + name: "not allowed for interactive commands", + target: tsrv.Config.Hostname, + command: []string{}, + assertRun: assert.Error, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + // Configure command with a real file path. + testFile := filepath.Join(t.TempDir(), "test.txt") + cmd := make([]string, 0, len(tc.command)) + for _, arg := range tc.command { + cmd = append(cmd, strings.ReplaceAll(arg, "test.txt", testFile)) + } + + err := Run(t.Context(), append([]string{ + "ssh", + "--insecure", + "-f", + u.Username + "@" + tc.target, + }, cmd...)) + tc.assertRun(t, err) + if tc.assertCommandEffect != nil { + tc.assertCommandEffect(t, testFile) + } + }) + } +}