Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,14 @@ type SSHOptions struct {
// machine. If provided, it will be used instead of establishing a connection
// to the target host and executing the command remotely.
LocalCommandExecutor func(string, []string) error
// OnChildAuthenticate is a function to run in the child process during
// --fork-after authentications. It runs after authentication completes
// but before the session begins.
OnChildAuthenticate func() error
}

func (opts SSHOptions) forkAfterAuthentication() bool {
return opts.OnChildAuthenticate != nil
}

// WithHostAddress returns a SSHOptions which overrides the
Expand All @@ -1918,6 +1926,15 @@ func WithLocalCommandExecutor(executor func(string, []string) error) func(*SSHOp
}
}

// WithForkAfterAuthentication indicates that tsh is currently reexec-ing
// for --fork-after-authentication. The given function is called after
// authentication is complete but before the session starts.
func WithForkAfterAuthentication(onAuthenticate func() error) func(*SSHOptions) {
return func(opt *SSHOptions) {
opt.OnChildAuthenticate = onAuthenticate
}
}

// SSH connects to a node and, if 'command' is specified, executes the command on it,
// otherwise runs interactive shell
//
Expand Down Expand Up @@ -1960,9 +1977,14 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
}

if len(nodeAddrs) > 1 {
if options.forkAfterAuthentication() {
return &NonRetryableError{
Err: trace.BadParameter("fork after authentication not supported for commands on multiple nodes"),
}
}
return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command)
}
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor)
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options)
}

// ConnectToNode attempts to establish a connection to the node resolved to by the provided
Expand Down Expand Up @@ -2165,7 +2187,7 @@ func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *Cluster
return nodeClient, trace.Wrap(err)
}

func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, commandExecutor func(string, []string) error) error {
func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, options SSHOptions) error {
cluster := clt.ClusterName()
ctx, span := tc.Tracer.Start(
ctx,
Expand All @@ -2189,6 +2211,12 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
return trace.Wrap(err)
}
defer nodeClient.Close()

if options.OnChildAuthenticate != nil {
if err := options.OnChildAuthenticate(); err != nil {
return trace.Wrap(err)
}
}
// If forwarding ports were specified, start port forwarding.
if err := tc.startPortForwarding(ctx, nodeClient); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -2220,11 +2248,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt

// After port forwarding, run a local command that uses the connection, and
// then disconnect.
if commandExecutor != nil {
if options.LocalCommandExecutor != nil {
if len(tc.Config.LocalForwardPorts) == 0 {
fmt.Println("Executing command locally without connecting to any servers. This makes no sense.")
}
return commandExecutor(tc.Config.HostLogin, command)
return options.LocalCommandExecutor(tc.Config.HostLogin, command)
}

if len(command) > 0 {
Expand Down Expand Up @@ -2259,7 +2287,7 @@ func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context,

// Issue "shell" request to the first matching node.
fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes match the label selector, picking first: %q\n", nodeAddrs[0])
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, nil)
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, SSHOptions{})
}

func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *NodeClient) error {
Expand Down
148 changes: 148 additions & 0 deletions lib/client/reexec/reexec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package reexec

import (
"context"
"errors"
"io"
"os"
"os/exec"
"strings"

"github.com/gravitational/trace"
)

// NotifyFileSignal signals on the returned channel when the provided file
// receives a signal (a one-byte read).
func NotifyFileSignal(f *os.File) <-chan error {
errorCh := make(chan error, 1)
go func() {
n, err := f.Read(make([]byte, 1))
if n > 0 {
errorCh <- nil
} else if err == nil {
// this should be impossible according to the io.Reader contract
errorCh <- io.ErrUnexpectedEOF
} else {
errorCh <- err
}
}()
return errorCh
}

// SignalAndClose writes a byte to the provided file (to signal a caller of
// NotifyFileSignal) and closes it.
func SignalAndClose(f *os.File) error {
_, err := f.Write([]byte{0x00})
return trace.NewAggregate(err, f.Close())
}

// ForkAuthenticateParams are the parameters to RunForkAuthenticate.
type ForkAuthenticateParams struct {
// GetArgs gets the arguments to re-exec with, excluding the executable
// (equivalent to os.Args[1:]).
GetArgs func(signalFd, killFd uint64) []string
// executable is the executable to run while re-execing. Overridden in tests.
executable string
// Stdin is the child process' stdin.
Stdin io.Reader
// Stdout is the child process' stdout.
Stdout io.Writer
// Stderr is the child process' stderr.
Stderr io.Writer
}

// RunForkAuthenticate re-execs the current executable and waits for any of
// the following:
// - The child process exits (usually in error).
// - The child process signals the parent that it is ready to be disowned.
// - The context is canceled.
func RunForkAuthenticate(ctx context.Context, params ForkAuthenticateParams) error {
if params.executable == "" {
executable, err := getExecutable()
if err != nil {
return trace.Wrap(err)
}
params.executable = executable
}
cmd := exec.Command(params.executable)
// Set up signal pipes.
disownR, disownW, err := os.Pipe()
if err != nil {
return trace.Wrap(err)
}
killR, killW, err := os.Pipe()
if err != nil {
return trace.Wrap(err)
}
defer func() {
// If the child is still listening, kill it. If the child successfully
// disowned, this will do nothing.
SignalAndClose(killW)
killR.Close()
disownW.Close()
disownR.Close()
}()

signalFd, killFd := configureReexecForOS(cmd, disownW, killR)
cmd.Args = append(cmd.Args, params.GetArgs(signalFd, killFd)...)
cmd.Args[0] = os.Args[0]
cmd.Stdin = params.Stdin
cmd.Stdout = params.Stdout
cmd.Stderr = params.Stderr

if err := cmd.Start(); err != nil {
return trace.Wrap(err)
}

// Clean up parent end of pipes.
if err := disownW.Close(); err != nil {
return trace.NewAggregate(err, killAndWaitProcess(cmd))
}
if err := killR.Close(); err != nil {
return trace.NewAggregate(err, killAndWaitProcess(cmd))
}

select {
case err := <-NotifyFileSignal(disownR):
if err == nil {
return trace.Wrap(cmd.Process.Release())
} else if errors.Is(err, io.EOF) {
// EOF means the child process exited, no need to report it on top of kill/wait.
return trace.Wrap(killAndWaitProcess(cmd))
}
return trace.NewAggregate(err, killAndWaitProcess(cmd))
case <-ctx.Done():
return trace.NewAggregate(ctx.Err(), killAndWaitProcess(cmd))
}
}

func killAndWaitProcess(cmd *exec.Cmd) error {
if err := cmd.Process.Kill(); err != nil {
return trace.Wrap(err)
}
err := cmd.Wait()
var execErr *exec.ExitError
if errors.As(err, &execErr) && execErr.ExitCode() != 0 {
return trace.Wrap(err)
} else if err != nil && strings.Contains(err.Error(), "signal: killed") {
// If the process was successfully killed, there is no issue.
return nil
}
return trace.Wrap(err)
}
43 changes: 43 additions & 0 deletions lib/client/reexec/reexec_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

//go:build darwin

package reexec

import (
"os"
"os/exec"
"syscall"

"github.com/gravitational/trace"
)

// getExecutable gets the path to the executable that should be used for re-exec.
func getExecutable() (string, error) {
executable, err := os.Executable()
return executable, trace.Wrap(err)
}

// configureReexecForOS configures the command with files to inherit and
// os-specific tweaks.
func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
cmd.ExtraFiles = []*os.File{signal, kill}
return 3, 4
}
40 changes: 40 additions & 0 deletions lib/client/reexec/reexec_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

//go:build linux

package reexec

import (
"os"
"os/exec"
"syscall"
)

// getExecutable gets the path to the executable that should be used for re-exec.
func getExecutable() (string, error) {
return "/proc/self/exe", nil
}

// configureReexecForOS configures the command with files to inherit and
// os-specific tweaks.
func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
cmd.ExtraFiles = []*os.File{signal, kill}
return 3, 4
}
Loading
Loading