diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index 9b1a2f1149be5..4c685f4a24e52 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -22,6 +22,7 @@ import ( "sync" "sync/atomic" + "github.com/google/uuid" "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" @@ -29,6 +30,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" ) // Client is a wrapper around ssh.Client that adds tracing support. @@ -166,9 +168,65 @@ func (c *Client) OpenChannel( }, reqs, err } -// NewSession creates a new SSH session that is passed tracing context -// so that spans may be correlated properly over the ssh connection. +// SessionParams are session parameters supported by Teleport to provide additional +// session context or parameters to the server. +type SessionParams struct { + // WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server. + WebProxyAddr string + // Reason is a reason attached to started sessions meant to describe their intent. + Reason string + // Invited is a list of people invited to a session. + Invited []string + // DisplayParticipantRequirements is set if debug information about participants requirements + // should be printed in moderated sessions. + DisplayParticipantRequirements bool + // JoinSessionID is the ID of a session to join. + JoinSessionID string + // JoinMode is the participant mode to join the session with. + // Required if JoinSessionID is set. + JoinMode types.SessionParticipantMode + // ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session + // to check for valid FileTransferRequests. + ModeratedSessionID string +} + +// ParseSessionParams unmarshals session parameters which have been [ssh.Marshal]ed by the client +// and provided as extra data in the session channel request. If the provided data is empty, nil params +// will be returned with a nil error. +func ParseSessionParams(data []byte) (*SessionParams, error) { + if len(data) == 0 { + return nil, nil + } + + var params SessionParams + if err := ssh.Unmarshal(data, ¶ms); err != nil { + return nil, trace.Wrap(err) + } + + if params.JoinSessionID != "" { + if _, err := uuid.Parse(params.JoinSessionID); err != nil { + return nil, trace.Wrap(err, "failed to parse join session ID: %v", params.JoinSessionID) + } + + switch params.JoinMode { + case types.SessionModeratorMode, types.SessionObserverMode, types.SessionPeerMode: + default: + return nil, trace.BadParameter("Unrecognized session participant mode: %q", params.JoinMode) + } + } + + return ¶ms, nil +} + +// NewSession creates a new SSH session. This session is passed a tracing context so that +// spans may be correlated properly over the ssh connection. func (c *Client) NewSession(ctx context.Context) (*Session, error) { + return c.NewSessionWithParams(ctx, nil) +} + +// NewSessionWithParams creates a new SSH session with the given (optional) params. This session is +// passed a tracing context so that spans may be correlated properly over the ssh connection. +func (c *Client) NewSessionWithParams(ctx context.Context, sessionParams *SessionParams) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) ctx, span := tracer.Start( @@ -195,9 +253,16 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { contexts: make(map[string][]context.Context), } + // If we are connected to a Teleport server, send session params in the session request. + // If the server does not support session parameters in the extra data, it will be ignored. + var sessionData []byte + if sessionParams != nil && c.capability == tracingSupported { + sessionData = ssh.Marshal(sessionParams) + } + // open a session manually so we can take ownership of the // requests chan - ch, reqs, err := wrapper.OpenChannel("session", nil) + ch, reqs, err := wrapper.OpenChannel("session", sessionData) if err != nil { return nil, trace.Wrap(err) } @@ -218,7 +283,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { } // RequestHandlerFn is an ssh request handler function. -type RequestHandlerFn func(ctx context.Context, ch *ssh.Request) +type RequestHandlerFn func(ctx context.Context, req *ssh.Request) // HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the // provided type within a session. If the type is already being handled, an error is returned. diff --git a/integration/integration_test.go b/integration/integration_test.go index 7fc6626c11d43..deed9d2e8f49e 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1133,7 +1133,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { ) assert.NoError(t, err) - errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil) + errCh <- nodeClient.RunInteractiveShell(ctx, "", "", nil) assert.NoError(t, nodeClient.Close()) }() @@ -7978,7 +7978,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) { isNilOrEOFErr(t, transferSess.Close()) }) - err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) + err = transferSess.Setenv(ctx, telesftp.EnvModeratedSessionID, sessTracker.GetSessionID()) require.NoError(t, err) err = transferSess.RequestSubsystem(ctx, teleport.SFTPSubsystem) @@ -8040,7 +8040,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) { require.NoError(t, transferSess.Close()) }) - err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) + err = transferSess.Setenv(ctx, telesftp.EnvModeratedSessionID, sessTracker.GetSessionID()) require.NoError(t, err) // Test that only operations needed to complete the download diff --git a/lib/client/api.go b/lib/client/api.go index 9cf506d7ea74a..37abe98c5d5fe 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -28,7 +28,6 @@ import ( "errors" "fmt" "io" - "maps" "net" "net/url" "os" @@ -94,7 +93,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -2285,7 +2283,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt // Reuse the existing nodeClient we connected above. return nodeClient.RunCommand(ctx, command) } - return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)) + return trace.Wrap(nodeClient.RunInteractiveShell(ctx, "", "", nil)) } func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error { @@ -2439,7 +2437,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan } // running shell with a given session means "join" it: - err = nc.RunInteractiveShell(ctx, mode, session, beforeStart) + err = nc.RunInteractiveShell(ctx, sessionID.String(), mode, beforeStart) return trace.Wrap(err) } @@ -2685,7 +2683,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, clt *ClusterClient, return trace.Wrap(err) } - return trace.Wrap(nodeClient.TransferFiles(ctx, cfg)) + return trace.Wrap(nodeClient.TransferFiles(ctx, cfg, "" /*moderatedSessionID*/)) } // ListNodesWithFilters returns all nodes that match the filters in the current cluster @@ -3113,18 +3111,6 @@ func (tc *TeleportClient) writeCommandResults(nodes []execResult) error { return nil } -func (tc *TeleportClient) newSessionEnv() map[string]string { - env := map[string]string{ - teleport.SSHSessionWebProxyAddr: tc.WebProxyAddr, - } - if tc.SessionID != "" { - env[sshutils.SessionEnvVar] = tc.SessionID - } - - maps.Copy(env, tc.ExtraEnvs) - return env -} - // getProxyLogin determines which SSH principal to use when connecting to proxy. func (tc *TeleportClient) getProxySSHPrincipal() string { if tc.ProxySSHPrincipal != "" { diff --git a/lib/client/client.go b/lib/client/client.go index c2ff704c1dac2..6875f7f721621 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -28,7 +28,6 @@ import ( "net" "os" "path/filepath" - "strconv" "strings" "sync" "time" @@ -388,10 +387,10 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co return nc, nil } -// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr +// RunInteractiveShell creates or joins an interactive shell on the node and copies stdin/stdout/stderr // to and from the node and local shell. This will block until the interactive shell on the node // is terminated. -func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error { +func (c *NodeClient) RunInteractiveShell(ctx context.Context, joinSessionID string, joinMode types.SessionParticipantMode, beforeStart func(io.Writer)) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/RunInteractiveShell", @@ -399,28 +398,21 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session ) defer span.End() - env := c.TC.newSessionEnv() - env[teleport.EnvSSHJoinMode] = string(mode) - env[teleport.EnvSSHSessionReason] = c.TC.Config.Reason - env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(c.TC.Config.DisplayParticipantRequirements) - encoded, err := json.Marshal(&c.TC.Config.Invited) - if err != nil { - return trace.Wrap(err) + sessionParams := &tracessh.SessionParams{ + WebProxyAddr: c.WebProxyAddr(), + Reason: c.TC.Config.Reason, + Invited: c.TC.Config.Invited, + DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements, + JoinSessionID: joinSessionID, + JoinMode: joinMode, } - env[teleport.EnvSSHSessionInvited] = string(encoded) - // Overwrite "SSH_SESSION_WEBPROXY_ADDR" with the public addr reported by the proxy. Otherwise, - // this would be set to the localhost addr (tc.WebProxyAddr) used for Web UI client connections. - if c.ProxyPublicAddr != "" && c.TC.WebProxyAddr != c.ProxyPublicAddr { - env[teleport.SSHSessionWebProxyAddr] = c.ProxyPublicAddr - } - - nodeSession, err := newSession(ctx, c, sessToJoin, env, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences) + nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences) if err != nil { return trace.Wrap(err) } - if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil { + if err = nodeSession.runShell(ctx, sessionParams, beforeStart, c.TC.OnShellCreated); err != nil { var exitErr *ssh.ExitError var exitMissingErr *ssh.ExitMissingError switch err := trace.Unwrap(err); { @@ -616,13 +608,19 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R } } - nodeSession, err := newSession(ctx, c, nil, c.TC.newSessionEnv(), c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences) + sessionParams := &tracessh.SessionParams{ + WebProxyAddr: c.WebProxyAddr(), + Reason: c.TC.Config.Reason, + Invited: c.TC.Config.Invited, + DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements, + } + + nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences) if err != nil { return trace.Wrap(err) } defer nodeSession.Close() - - err = nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) + err = nodeSession.runCommand(ctx, sessionParams, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) if err != nil { c.TC.SetExitStatus(getExitStatus(err)) } @@ -745,7 +743,7 @@ func newClientConn( } // TransferFiles transfers files over SFTP. -func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config) error { +func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config, moderatedSessionID string) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/TransferFiles", @@ -753,7 +751,7 @@ func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config) error ) defer span.End() - if err := cfg.TransferFiles(ctx, c.Client.Client); err != nil { + if err := cfg.TransferFiles(ctx, c.Client, moderatedSessionID); err != nil { // TODO(tross): DELETE IN 19.0.0 - Older versions of Teleport would return // a trace.BadParameter error when ~user path expansion was rejected, and // reauthentication logic is attempted on BadParameter errors. @@ -1029,3 +1027,13 @@ func GetPaginatedSessions(ctx context.Context, fromUTC, toUTC time.Time, pageSiz } return sessions, nil } + +// WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server. +func (c *NodeClient) WebProxyAddr() string { + // Prioritize the public addr reported by the proxy. Otherwise, this would + // return the localhost addr used for Web UI client connections. + if c.ProxyPublicAddr != "" { + return c.ProxyPublicAddr + } + return c.TC.WebProxyAddr +} diff --git a/lib/client/client_test.go b/lib/client/client_test.go index 81313530b6858..b747c1992e0fa 100644 --- a/lib/client/client_test.go +++ b/lib/client/client_test.go @@ -36,7 +36,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/lib/observability/tracing" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" ) @@ -48,12 +47,13 @@ func TestHelperFunctions(t *testing.T) { func TestNewSession(t *testing.T) { nc := &NodeClient{ + TC: &TeleportClient{}, Tracer: tracing.NoopProvider().Tracer("test"), } ctx := context.Background() // defaults: - ses, err := newSession(ctx, nc, nil, nil, nil, nil, nil, true) + ses, err := newSession(ctx, nc, nil, nil, nil, nil, true) require.NoError(t, err) require.NotNil(t, ses) require.Equal(t, nc, ses.NodeClient()) @@ -61,14 +61,6 @@ func TestNewSession(t *testing.T) { require.Equal(t, os.Stderr, ses.terminal.Stderr()) require.Equal(t, os.Stdout, ses.terminal.Stdout()) require.Equal(t, os.Stdin, ses.terminal.Stdin()) - - // pass environ map - env := map[string]string{ - sshutils.SessionEnvVar: "session-id", - } - ses, err = newSession(ctx, nc, nil, env, nil, nil, nil, true) - require.NoError(t, err) - require.NotNil(t, ses) } // TestProxyConnection verifies that client or server-side disconnect diff --git a/lib/client/session.go b/lib/client/session.go index 96b2ee2926a83..4e5dc8eab6f9b 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -20,6 +20,7 @@ package client import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -27,6 +28,7 @@ import ( "net" "os" "os/signal" + "strconv" "strings" "sync" "sync/atomic" @@ -102,8 +104,7 @@ type NodeSession struct { // of another user func newSession(ctx context.Context, client *NodeClient, - joinSession types.SessionTracker, - env map[string]string, + sessionParams *tracessh.SessionParams, stdin io.Reader, stdout io.Writer, stderr io.Writer, @@ -117,8 +118,21 @@ func newSession(ctx context.Context, return nil, trace.Wrap(err) } - if env == nil { - env = make(map[string]string) + env := make(map[string]string) + maps.Copy(env, client.TC.ExtraEnvs) + + // TODO(Joerger): DELETE IN v20.0.0 - session params are provided in the session + // request as extra data rather than env vars. + if sessionParams != nil { + env[teleport.SSHSessionWebProxyAddr] = sessionParams.WebProxyAddr + env[teleport.EnvSSHJoinMode] = string(sessionParams.JoinMode) + env[teleport.EnvSSHSessionReason] = sessionParams.Reason + env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(sessionParams.DisplayParticipantRequirements) + encoded, err := json.Marshal(&sessionParams.Invited) + if err != nil { + return nil, trace.Wrap(err) + } + env[teleport.EnvSSHSessionInvited] = string(encoded) } ns := &NodeSession{ @@ -130,11 +144,11 @@ func newSession(ctx context.Context, terminal: term, shouldClearOnExit: client.FIPSEnabled || isFIPS(), } - // if we're joining an existing session, we need to assume that session's - // existing/current terminal size: - if joinSession != nil { - sessionID := joinSession.GetSessionID() - terminalSize, err := client.GetRemoteTerminalSize(ctx, sessionID) + + if sessionParams != nil && sessionParams.JoinSessionID != "" { + // if we're joining an existing session, we need to assume that session's + // existing/current terminal size: + terminalSize, err := client.GetRemoteTerminalSize(ctx, sessionParams.JoinSessionID) if err != nil { return nil, trace.Wrap(err) } @@ -144,10 +158,10 @@ func newSession(ctx context.Context, if err != nil { log.ErrorContext(ctx, "Failed to resize terminal", "error", err) } - } - ns.env[sshutils.SessionEnvVar] = sessionID + // TODO(Joerger): DELETE IN v20.0.0 - session env var is no longer used for session joining. + ns.env[sshutils.SessionEnvVar] = sessionParams.JoinSessionID } // Close the Terminal when finished. @@ -172,7 +186,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func(s *tracessh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, sessionParams *tracessh.SessionParams, sessionCallback func(s *tracessh.Session) error) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/regularSession", @@ -180,7 +194,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func( ) defer span.End() - session, err := ns.createServerSession(ctx) + session, err := ns.createServerSession(ctx, nil) if err != nil { return trace.Wrap(err) } @@ -192,7 +206,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func( type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context, sessionParams *tracessh.SessionParams) (*tracessh.Session, error) { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/createServerSession", @@ -200,7 +214,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi ) defer span.End() - sess, err := ns.nodeClient.Client.NewSession(ctx) + sess, err := ns.nodeClient.Client.NewSessionWithParams(ctx, sessionParams) if err != nil { return nil, trace.Wrap(err) } @@ -266,7 +280,7 @@ func selectKeyAgent(ctx context.Context, tc *TeleportClient) sshagent.ClientGett // interactiveSession creates an interactive session on the remote node, executes // the given callback on it, and waits for the session to end -func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, sessionCallback interactiveCallback) error { +func (ns *NodeSession) interactiveSession(ctx context.Context, sessionParams *tracessh.SessionParams, sessionCallback interactiveCallback) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/interactiveSession", @@ -280,7 +294,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio termType = teleport.SafeTerminalType } // create the server-side session: - sess, err := ns.createServerSession(ctx) + sess, err := ns.createServerSession(ctx, sessionParams) if err != nil { return trace.Wrap(err) } @@ -306,6 +320,11 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio ns.watchSignals(remoteTerm) } + mode := types.SessionPeerMode + if sessionParams != nil && sessionParams.JoinMode != "" { + mode = sessionParams.JoinMode + } + // start piping input into the remote shell and pipe the output from // the remote shell into stdout: ns.pipeInOut(ctx, remoteTerm, mode, sess) @@ -510,8 +529,8 @@ func (s *sessionWriter) Write(p []byte) (int, error) { } // runShell executes user's shell on the remote node under an interactive session -func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error { +func (ns *NodeSession) runShell(ctx context.Context, sessionParams *tracessh.SessionParams, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { + return ns.interactiveSession(ctx, sessionParams, func(s *tracessh.Session, shell io.ReadWriteCloser) error { w := &sessionWriter{ tshOut: ns.nodeClient.TC.Stdout, session: s, @@ -539,7 +558,7 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa // runCommand executes a "exec" request either in interactive mode (with a // TTY attached) or non-intractive mode (no TTY). -func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, shellCallback ShellCreatedCallback, interactive bool) error { +func (ns *NodeSession) runCommand(ctx context.Context, sessionParams *tracessh.SessionParams, cmd []string, shellCallback ShellCreatedCallback, interactive bool) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/runCommand", @@ -553,7 +572,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // keyboard based signals will be propogated to the TTY on the server which is // where all signal handling will occur. if interactive { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error { + return ns.interactiveSession(ctx, sessionParams, func(s *tracessh.Session, term io.ReadWriteCloser) error { err := s.Start(ctx, strings.Join(cmd, " ")) if err != nil { return trace.Wrap(err) @@ -580,7 +599,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // Unfortunately at the moment the Go SSH library Teleport uses does not // support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and // SSH client, and try and exit as gracefully as possible. - return ns.regularSession(ctx, func(s *tracessh.Session) error { + return ns.regularSession(ctx, sessionParams, func(s *tracessh.Session) error { errCh := make(chan error, 1) go func() { errCh <- s.Run(ctx, strings.Join(cmd, " ")) diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index ebc4704719121..fe7e1187b5767 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -20,6 +20,7 @@ package srv import ( "context" + "encoding/json" "fmt" "io" "log/slog" @@ -53,6 +54,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/envutils" @@ -314,6 +316,9 @@ type ServerContext struct { // term holds PTY if it was requested by the session. term Terminal + // sessionParams are parameters associated with this server session. + sessionParams *tracessh.SessionParams + // session holds the active session (if there's an active one). session *session @@ -436,7 +441,7 @@ type ServerContext struct { // the ServerContext is closed. The ctx parameter should be a child of the ctx // associated with the scope of the parent ConnectionContext to ensure that // cancellation of the ConnectionContext propagates to the ServerContext. -func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) { +func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, sessionParams *tracessh.SessionParams, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) { recConfig, err := srv.GetAccessPoint().GetSessionRecordingConfig(ctx) if err != nil { return nil, trace.Wrap(err) @@ -476,6 +481,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s id: int(atomic.AddInt32(&ctxID, int32(1))), env: make(map[string]string), srv: srv, + sessionParams: sessionParams, ExecResultCh: make(chan ExecResult, 10), SubsystemResultCh: make(chan SubsystemResult, 10), ClusterName: parent.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName], @@ -587,31 +593,6 @@ func (c *ServerContext) ID() int { return c.id } -// GetJoinParams gets join params if they are set. -// -// These params (env vars) are set synchronously between the "session" channel request -// and the "shell" / "exec" channel request. Therefore, these params are only guaranteed -// to be accurately set during and after the "shell" / "exec" channel request. -// -// TODO(Joerger): Rather than relying on the out-of-band env var params, we should -// provide session params upfront as extra data in the session channel request. -func (c *ServerContext) GetJoinParams() (string, types.SessionParticipantMode) { - c.mu.RLock() - defer c.mu.RUnlock() - - sid, found := c.getEnvLocked(sshutils.SessionEnvVar) - if !found { - return "", "" - } - - mode := types.SessionPeerMode // default - if modeString, found := c.getEnvLocked(teleport.EnvSSHJoinMode); found { - mode = types.SessionParticipantMode(modeString) - } - - return sid, mode -} - // SessionID returns the ID of the session in the context. // // This value is not set until during and after the "shell" / "exec" channel request. @@ -680,19 +661,35 @@ func (c *ServerContext) SetEnv(key, val string) { c.mu.Unlock() } -// GetEnv returns a environment variable within this context. -func (c *ServerContext) GetEnv(key string) (string, bool) { +// GetSessionParams gets session params for the current session. +func (c *ServerContext) GetSessionParams() tracessh.SessionParams { c.mu.RLock() defer c.mu.RUnlock() - return c.getEnvLocked(key) -} -func (c *ServerContext) getEnvLocked(key string) (string, bool) { - val, ok := c.env[key] - if ok { - return val, true + // Teleport ssh clients should provide session params upfront in the session channel request. + if c.sessionParams != nil { + return *c.sessionParams + } + + // If this is an old client, it will provide session params from + // env variables sometime between the session channel request and shell request. + // TODO(Joerger): DELETE IN v20.0.0 - just return empty params for an old Teleport client / openSSH client session. + sessionParams := tracessh.SessionParams{ + WebProxyAddr: c.env[teleport.SSHSessionWebProxyAddr], + Reason: c.env[teleport.EnvSSHSessionReason], + DisplayParticipantRequirements: utils.AsBool(c.env[teleport.EnvSSHSessionDisplayParticipantRequirements]), + JoinSessionID: c.env[sshutils.SessionEnvVar], + JoinMode: types.SessionParticipantMode(c.env[teleport.EnvSSHJoinMode]), + ModeratedSessionID: c.env[sftp.EnvModeratedSessionID], + } + + if invitedUsers := c.env[teleport.EnvSSHSessionInvited]; invitedUsers != "" { + if err := json.Unmarshal([]byte(invitedUsers), &sessionParams.Invited); err != nil { + slog.WarnContext(context.Background(), "Failed to parse invited users", "error", err) + } } - return c.Parent().GetEnv(key) + + return sessionParams } // setSession sets the context's session @@ -1166,11 +1163,15 @@ func buildEnvironment(ctx *ServerContext) []string { } // Set some Teleport specific environment variables: SSH_TELEPORT_USER, - // SSH_TELEPORT_HOST_UUID, and SSH_TELEPORT_CLUSTER_NAME. + // SSH_TELEPORT_HOST_UUID, SSH_TELEPORT_CLUSTER_NAME, and SSH_SESSION_WEBPROXY_ADDR. env.AddTrusted(teleport.SSHTeleportHostUUID, ctx.srv.ID()) env.AddTrusted(teleport.SSHTeleportClusterName, ctx.ClusterName) env.AddTrusted(teleport.SSHTeleportUser, ctx.Identity.TeleportUser) + if ctx.GetSessionParams().WebProxyAddr != "" { + env.AddTrusted(teleport.SSHSessionWebProxyAddr, ctx.GetSessionParams().WebProxyAddr) + } + // At the end gather all dynamically defined environment variables ctx.VisitEnv(env.AddUnique) diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index f13a700082cef..dbe46517b3f83 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -574,6 +574,10 @@ func (s *Server) Serve() { config.KeyExchanges = s.kexAlgorithms config.MACs = s.macAlgorithms + // Set the server version to Teleport to enable tracing and other Teleport + // specific features like joining. + config.ServerVersion = sshutils.SSHVersionPrefix + netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(s.Context()) if err != nil { s.logger.ErrorContext(s.Context(), "Unable to fetch cluster config", "error", err) @@ -889,7 +893,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha // Create context for this channel. This context will be closed when // forwarding is complete. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, nil) if err != nil { if err := nch.Reject(ssh.ConnectionFailed, "failed to open server context"); err != nil { s.logger.ErrorContext(ctx, "Error rejecting forwarded-tcpip channel", "error", err) @@ -999,7 +1003,7 @@ func (s *Server) checkTCPIPForwardRequest(ctx context.Context, r *ssh.Request) e // RBAC checks are only necessary when connecting to an agentless node if s.targetServer.IsOpenSSHNode() { - scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext, nil) if err != nil { return err } @@ -1062,7 +1066,7 @@ func (s *Server) handleChannel(ctx context.Context, nch ssh.NewChannel) { func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, req *sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when // forwarding is complete. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, nil) if err != nil { s.logger.ErrorContext(ctx, "Unable to create connection context", "error", err) s.stderrWrite(ctx, ch, "Unable to create connection context.") @@ -1114,12 +1118,22 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r // the remote host. Once the session channel has been established, this function's loop handles // all the "exec", "subsystem" and "shell" requests. func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { + // sessionParams will not be passed by old clients (< v19) or OpenSSH clients. + sessionParams, err := tracessh.ParseSessionParams(nch.ExtraData()) + if err != nil { + s.logger.ErrorContext(ctx, "Failed to parse request data", "data", string(nch.ExtraData()), "error", err) + if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + // Create context for this channel. This context will be closed when the // session request is complete. // There is no need for the forwarding server to initiate disconnects, // based on teleport business logic, because this logic is already // done on the server's terminating side. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, sessionParams) if err != nil { s.logger.WarnContext(ctx, "Server context setup failed", "error", err) if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("server context setup failed: %v", err)); err != nil { @@ -1140,7 +1154,7 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { // create the remote session channel before accepting the local // channel request; this allows us to propagate the rejection // reason/message in the event the channel is rejected. - remoteSession, err := s.remoteClient.NewSession(ctx) + remoteSession, err := s.remoteClient.NewSessionWithParams(ctx, sessionParams) if err != nil { s.logger.WarnContext(ctx, "Remote session open failed", "error", err) reason, msg := ssh.ConnectionFailed, fmt.Sprintf("remote session open failed: %v", err) diff --git a/lib/srv/git/forward.go b/lib/srv/git/forward.go index 381624c096a4d..9bf7ff21c1532 100644 --- a/lib/srv/git/forward.go +++ b/lib/srv/git/forward.go @@ -374,7 +374,7 @@ func (s *ForwardServer) onConnection(ctx context.Context, ccx *sshutils.Connecti // TODO(greedy52) decouple from srv.NewServerContext. We only need // connection monitoring. - serverCtx, err := srv.NewServerContext(ctx, ccx, s, identityCtx) + serverCtx, err := srv.NewServerContext(ctx, ccx, s, identityCtx, nil) if err != nil { return nil, trace.Wrap(err) } @@ -400,11 +400,18 @@ func (s *ForwardServer) onChannel(ctx context.Context, ccx *sshutils.ConnectionC return } + // sessionParams will not be passed by old clients (< v19) or OpenSSH clients. + sessionParams, err := tracessh.ParseSessionParams(nch.ExtraData()) + if err != nil { + s.reply.RejectWithAcceptError(ctx, nch, err) + return + } + if s.remoteClient == nil { s.reply.RejectWithNewRemoteSessionError(ctx, nch, trace.NotFound("missing remote client")) return } - remoteSession, err := s.remoteClient.NewSession(ctx) + remoteSession, err := s.remoteClient.NewSessionWithParams(ctx, sessionParams) if err != nil { s.reply.RejectWithNewRemoteSessionError(ctx, nch, err) return diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index c7c77ccf3ee0c..bdd448705f96f 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1215,7 +1215,7 @@ func (s *Server) getNetworkingProcess(scx *srv.ServerContext) (*networking.Proce // the server connection is closed. func (s *Server) startNetworkingProcess(scx *srv.ServerContext) (*networking.Process, error) { // Create context for the networking process. - nsctx, err := srv.NewServerContext(context.Background(), scx.ConnectionContext, s, scx.Identity) + nsctx, err := srv.NewServerContext(context.Background(), scx.ConnectionContext, s, scx.Identity, nil) if err != nil { return nil, trace.Wrap(err) } @@ -1400,7 +1400,7 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont s.rejectChannel(ctx, nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleSessionRequests(ctx, ccx, identityContext, ch, requests) + go s.handleSessionRequests(ctx, ccx, identityContext, nil, ch, requests) return default: s.rejectChannel(ctx, nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) @@ -1443,6 +1443,15 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont } decr = d } + + // SessionParams are not passed by old clients (