diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index 714a8607f429e..28f0f8d014726 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -22,6 +22,7 @@ import ( "sync" "sync/atomic" + "github.com/google/uuid" "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" @@ -29,6 +30,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" ) // Client is a wrapper around ssh.Client that adds tracing support. @@ -184,9 +186,65 @@ func (c *Client) OpenChannel( }, reqs, err } -// NewSession creates a new SSH session that is passed tracing context -// so that spans may be correlated properly over the ssh connection. +// SessionParams are session parameters supported by Teleport to provide additional +// session context or parameters to the server. +type SessionParams struct { + // WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server. + WebProxyAddr string + // Reason is a reason attached to started sessions meant to describe their intent. + Reason string + // Invited is a list of people invited to a session. + Invited []string + // DisplayParticipantRequirements is set if debug information about participants requirements + // should be printed in moderated sessions. + DisplayParticipantRequirements bool + // JoinSessionID is the ID of a session to join. + JoinSessionID string + // JoinMode is the participant mode to join the session with. + // Required if JoinSessionID is set. + JoinMode types.SessionParticipantMode + // ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session + // to check for valid FileTransferRequests. + ModeratedSessionID string +} + +// ParseSessionParams unmarshals session parameters which have been [ssh.Marshal]ed by the client +// and provided as extra data in the session channel request. If the provided data is empty, nil params +// will be returned with a nil error. +func ParseSessionParams(data []byte) (*SessionParams, error) { + if len(data) == 0 { + return nil, nil + } + + var params SessionParams + if err := ssh.Unmarshal(data, ¶ms); err != nil { + return nil, trace.Wrap(err) + } + + if params.JoinSessionID != "" { + if _, err := uuid.Parse(params.JoinSessionID); err != nil { + return nil, trace.Wrap(err, "failed to parse join session ID: %v", params.JoinSessionID) + } + + switch params.JoinMode { + case types.SessionModeratorMode, types.SessionObserverMode, types.SessionPeerMode: + default: + return nil, trace.BadParameter("Unrecognized session participant mode: %q", params.JoinMode) + } + } + + return ¶ms, nil +} + +// NewSession creates a new SSH session. This session is passed a tracing context so that +// spans may be correlated properly over the ssh connection. func (c *Client) NewSession(ctx context.Context) (*Session, error) { + return c.NewSessionWithParams(ctx, nil) +} + +// NewSessionWithParams creates a new SSH session with the given (optional) params. This session is +// passed a tracing context so that spans may be correlated properly over the ssh connection. +func (c *Client) NewSessionWithParams(ctx context.Context, sessionParams *SessionParams) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) ctx, span := tracer.Start( @@ -213,9 +271,16 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { contexts: make(map[string][]context.Context), } + // If we are connected to a Teleport server, send session params in the session request. + // If the server does not support session parameters in the extra data, it will be ignored. + var sessionData []byte + if sessionParams != nil && c.capability == tracingSupported { + sessionData = ssh.Marshal(sessionParams) + } + // open a session manually so we can take ownership of the // requests chan - ch, reqs, err := wrapper.OpenChannel("session", nil) + ch, reqs, err := wrapper.OpenChannel("session", sessionData) if err != nil { return nil, trace.Wrap(err) } @@ -236,7 +301,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { } // RequestHandlerFn is an ssh request handler function. -type RequestHandlerFn func(ctx context.Context, ch *ssh.Request) +type RequestHandlerFn func(ctx context.Context, req *ssh.Request) // HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the // provided type within a session. If the type is already being handled, an error is returned. diff --git a/constants.go b/constants.go index 47b4defc1846f..c52e95ad3df36 100644 --- a/constants.go +++ b/constants.go @@ -831,9 +831,22 @@ const ( CurrentSessionIDRequest = "current-session-id@goteleport.com" // SessionIDQueryRequest is sent by clients to ask servers if they - // will generate their own session ID when a new session is created. + // will generate and share their own session ID when a new session + // is started (session and exec/shell channels accepted). + // + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. SessionIDQueryRequest = "session-id-query@goteleport.com" + // SessionIDQueryRequestV2 is sent by clients to ask servers if they + // will generate and share their own session ID when a new session + // channel is accepted, rather than when the shell/exec channel is. + // + // TODO(Joerger): DELETE IN v21.0.0 + // all v19+ servers set the session ID directly after accepting the session channel. + // clients should stop checking in v21, and servers should stop responding to the query in v22. + SessionIDQueryRequestV2 = "session-id-query-v2@goteleport.com" + // ForceTerminateRequest is an SSH request to forcefully terminate a session. ForceTerminateRequest = "x-teleport-force-terminate" diff --git a/integration/integration_test.go b/integration/integration_test.go index b1bba41d9fb88..1dc98c0efd614 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -149,8 +149,8 @@ func TestIntegrations(t *testing.T) { t.Run("BPFSessionDifferentiation", suite.bind(testBPFSessionDifferentiation)) t.Run("ClientIdleConnection", suite.bind(testClientIdleConnection)) t.Run("CmdLabels", suite.bind(testCmdLabels)) + t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters)) t.Run("ControlMaster", suite.bind(testControlMaster)) - t.Run("X11Forwarding", suite.bind(testX11Forwarding)) t.Run("CustomReverseTunnel", suite.bind(testCustomReverseTunnel)) t.Run("DataTransfer", suite.bind(testDataTransfer)) t.Run("DifferentPinnedIP", suite.bind(testDifferentPinnedIP)) @@ -200,12 +200,12 @@ func TestIntegrations(t *testing.T) { t.Run("TrustedClustersRoleMapChanges", suite.bind(testTrustedClustersRoleMapChanges)) t.Run("TrustedClustersWithLabels", suite.bind(testTrustedClustersWithLabels)) t.Run("TrustedClustersSkipNameValidation", suite.bind(testTrustedClustersSkipNameValidation)) - t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters)) t.Run("TrustedTunnelNode", suite.bind(testTrustedTunnelNode)) t.Run("TwoClustersProxy", suite.bind(testTwoClustersProxy)) t.Run("TwoClustersTunnel", suite.bind(testTwoClustersTunnel)) t.Run("UUIDBasedProxy", suite.bind(testUUIDBasedProxy)) t.Run("WindowChange", suite.bind(testWindowChange)) + t.Run("X11Forwarding", suite.bind(testX11Forwarding)) } // testDifferentPinnedIP tests connection is rejected when source IP doesn't match the pinned one @@ -522,6 +522,7 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } } + // Test streaming events and recording. capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID) findByType := func(et string) apievents.AuditEvent { @@ -532,19 +533,6 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } return nil } - // helper that asserts that a session event is also included in the - // general audit log. - requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) { - t.Helper() - auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ - To: time.Now(), - EventTypes: []string{sessionEvent.GetType()}, - }) - require.NoError(t, err) - require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { - return ae.GetID() == sessionEvent.GetID() - })) - } // there should always be 'session.start' event (and it must be first) first := sessionEvents[0].(*apievents.SessionStart) @@ -552,19 +540,16 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { require.Equal(t, first, start) require.Equal(t, sessionID, start.SessionID) require.NotEmpty(t, start.TerminalSize) - requireInAuditLog(t, start) // there should always be 'session.end' event end := findByType(events.SessionEndEvent).(*apievents.SessionEnd) require.NotNil(t, end) require.Equal(t, sessionID, end.SessionID) - requireInAuditLog(t, end) // there should always be 'session.leave' event leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave) require.NotNil(t, leave) require.Equal(t, sessionID, leave.SessionID) - requireInAuditLog(t, leave) // all of them should have a proper time for _, e := range sessionEvents { @@ -575,6 +560,31 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { recorded := replaceNewlines(capturedStream) require.Regexp(t, ".*exit.*", recorded) require.Regexp(t, ".*echo hi.*", recorded) + + sessionEvents, _, err = site.SearchEvents(ctx, events.SearchEventsRequest{ + From: time.Time{}, + To: time.Now(), + EventTypes: []string{ + events.SessionStartEvent, + events.SessionLeaveEvent, + events.SessionEndEvent, + }, + }) + require.NoError(t, err) + + // Check that the events found above in the session stream show up in the backend. + require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == start.GetID() + }), "expected session events to contain session.start event") + require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == end.GetID() + }), "expected session events to contain session.end event") + require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == leave.GetID() + }), "expected session events to contain session.leave event") + + // Ensure there are no duplicate events, e.g. from proxy recording mode. + require.Len(t, sessionEvents, 3, "%d unexpected duplicate events", len(sessionEvents)-4) }) } } @@ -1120,7 +1130,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { ) assert.NoError(t, err) - errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil) + errCh <- nodeClient.RunInteractiveShell(ctx, "", "", nil) assert.NoError(t, nodeClient.Close()) }() @@ -7996,7 +8006,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) { isNilOrEOFErr(t, transferSess.Close()) }) - err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) + err = transferSess.Setenv(ctx, string(telesftp.EnvModeratedSessionID), sessTracker.GetSessionID()) require.NoError(t, err) err = transferSess.RequestSubsystem(ctx, teleport.SFTPSubsystem) @@ -8058,7 +8068,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) { require.NoError(t, transferSess.Close()) }) - err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) + err = transferSess.Setenv(ctx, string(telesftp.EnvModeratedSessionID), sessTracker.GetSessionID()) require.NoError(t, err) // Test that only operations needed to complete the download diff --git a/lib/client/api.go b/lib/client/api.go index 76df566e80ad6..048cf40037fef 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -94,7 +94,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -2303,7 +2302,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt // Reuse the existing nodeClient we connected above. return nodeClient.RunCommand(ctx, command) } - return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)) + return trace.Wrap(nodeClient.RunInteractiveShell(ctx, "", "", nil)) } func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error { @@ -2457,7 +2456,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan } // running shell with a given session means "join" it: - err = nc.RunInteractiveShell(ctx, mode, session, beforeStart) + err = nc.RunInteractiveShell(ctx, sessionID.String(), mode, beforeStart) return trace.Wrap(err) } @@ -3155,20 +3154,6 @@ func (tc *TeleportClient) writeCommandResults(nodes []execResult) error { return nil } -func (tc *TeleportClient) newSessionEnv() map[string]string { - env := map[string]string{ - teleport.SSHSessionWebProxyAddr: tc.WebProxyAddr, - } - if tc.SessionID != "" { - env[sshutils.SessionEnvVar] = tc.SessionID - } - - for key, val := range tc.ExtraEnvs { - env[key] = val - } - return env -} - // getProxyLogin determines which SSH principal to use when connecting to proxy. func (tc *TeleportClient) getProxySSHPrincipal() string { if tc.ProxySSHPrincipal != "" { diff --git a/lib/client/client.go b/lib/client/client.go index bb3aca78fc330..4dfed45e5e074 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -28,7 +28,6 @@ import ( "net" "os" "path/filepath" - "strconv" "strings" "sync" "time" @@ -391,10 +390,10 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co return nc, nil } -// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr +// RunInteractiveShell creates or joins an interactive shell on the node and copies stdin/stdout/stderr // to and from the node and local shell. This will block until the interactive shell on the node // is terminated. -func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error { +func (c *NodeClient) RunInteractiveShell(ctx context.Context, joinSessionID string, joinMode types.SessionParticipantMode, beforeStart func(io.Writer)) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/RunInteractiveShell", @@ -402,28 +401,21 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session ) defer span.End() - env := c.TC.newSessionEnv() - env[teleport.EnvSSHJoinMode] = string(mode) - env[teleport.EnvSSHSessionReason] = c.TC.Config.Reason - env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(c.TC.Config.DisplayParticipantRequirements) - encoded, err := json.Marshal(&c.TC.Config.Invited) - if err != nil { - return trace.Wrap(err) + sessionParams := &tracessh.SessionParams{ + WebProxyAddr: c.WebProxyAddr(), + Reason: c.TC.Config.Reason, + Invited: c.TC.Config.Invited, + DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements, + JoinSessionID: joinSessionID, + JoinMode: joinMode, } - env[teleport.EnvSSHSessionInvited] = string(encoded) - // Overwrite "SSH_SESSION_WEBPROXY_ADDR" with the public addr reported by the proxy. Otherwise, - // this would be set to the localhost addr (tc.WebProxyAddr) used for Web UI client connections. - if c.ProxyPublicAddr != "" && c.TC.WebProxyAddr != c.ProxyPublicAddr { - env[teleport.SSHSessionWebProxyAddr] = c.ProxyPublicAddr - } - - nodeSession, err := newSession(ctx, c, sessToJoin, env, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences) + nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences) if err != nil { return trace.Wrap(err) } - if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil { + if err = nodeSession.runShell(ctx, sessionParams, beforeStart, c.TC.OnShellCreated); err != nil { var exitErr *ssh.ExitError var exitMissingErr *ssh.ExitMissingError switch err := trace.Unwrap(err); { @@ -619,13 +611,19 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R } } - nodeSession, err := newSession(ctx, c, nil, c.TC.newSessionEnv(), c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences) + sessionParams := &tracessh.SessionParams{ + WebProxyAddr: c.WebProxyAddr(), + Reason: c.TC.Config.Reason, + Invited: c.TC.Config.Invited, + DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements, + } + + nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences) if err != nil { return trace.Wrap(err) } defer nodeSession.Close() - - err = nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) + err = nodeSession.runCommand(ctx, sessionParams, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) if err != nil { c.TC.SetExitStatus(getExitStatus(err)) } @@ -1009,3 +1007,13 @@ func GetPaginatedSessions(ctx context.Context, fromUTC, toUTC time.Time, pageSiz } return sessions, nil } + +// WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server. +func (c *NodeClient) WebProxyAddr() string { + // Prioritize the public addr reported by the proxy. Otherwise, this would + // return the localhost addr used for Web UI client connections. + if c.ProxyPublicAddr != "" { + return c.ProxyPublicAddr + } + return c.TC.WebProxyAddr +} diff --git a/lib/client/client_test.go b/lib/client/client_test.go index f53c6ec317cb1..5ee7659c4233a 100644 --- a/lib/client/client_test.go +++ b/lib/client/client_test.go @@ -36,7 +36,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/lib/observability/tracing" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" ) @@ -48,12 +47,13 @@ func TestHelperFunctions(t *testing.T) { func TestNewSession(t *testing.T) { nc := &NodeClient{ + TC: &TeleportClient{}, Tracer: tracing.NoopProvider().Tracer("test"), } ctx := context.Background() // defaults: - ses, err := newSession(ctx, nc, nil, nil, nil, nil, nil, true) + ses, err := newSession(ctx, nc, nil, nil, nil, nil, true) require.NoError(t, err) require.NotNil(t, ses) require.Equal(t, nc, ses.NodeClient()) @@ -61,14 +61,6 @@ func TestNewSession(t *testing.T) { require.Equal(t, os.Stderr, ses.terminal.Stderr()) require.Equal(t, os.Stdout, ses.terminal.Stdout()) require.Equal(t, os.Stdin, ses.terminal.Stdin()) - - // pass environ map - env := map[string]string{ - sshutils.SessionEnvVar: "session-id", - } - ses, err = newSession(ctx, nc, nil, env, nil, nil, nil, true) - require.NoError(t, err) - require.NotNil(t, ses) } // TestProxyConnection verifies that client or server-side disconnect diff --git a/lib/client/session.go b/lib/client/session.go index 69cdbddb0c6ec..65556e9f8a029 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -20,12 +20,15 @@ package client import ( "context" + "encoding/json" "errors" "fmt" "io" + "maps" "net" "os" "os/signal" + "strconv" "strings" "sync" "sync/atomic" @@ -101,8 +104,7 @@ type NodeSession struct { // of another user func newSession(ctx context.Context, client *NodeClient, - joinSession types.SessionTracker, - env map[string]string, + sessionParams *tracessh.SessionParams, stdin io.Reader, stdout io.Writer, stderr io.Writer, @@ -116,8 +118,21 @@ func newSession(ctx context.Context, return nil, trace.Wrap(err) } - if env == nil { - env = make(map[string]string) + env := make(map[string]string) + maps.Copy(env, client.TC.ExtraEnvs) + + // TODO(Joerger): DELETE IN v20.0.0 - session params are provided in the session + // request as extra data rather than env vars. + if sessionParams != nil { + env[teleport.SSHSessionWebProxyAddr] = sessionParams.WebProxyAddr + env[teleport.EnvSSHJoinMode] = string(sessionParams.JoinMode) + env[teleport.EnvSSHSessionReason] = sessionParams.Reason + env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(sessionParams.DisplayParticipantRequirements) + encoded, err := json.Marshal(&sessionParams.Invited) + if err != nil { + return nil, trace.Wrap(err) + } + env[teleport.EnvSSHSessionInvited] = string(encoded) } ns := &NodeSession{ @@ -129,11 +144,11 @@ func newSession(ctx context.Context, terminal: term, shouldClearOnExit: client.FIPSEnabled || isFIPS(), } - // if we're joining an existing session, we need to assume that session's - // existing/current terminal size: - if joinSession != nil { - sessionID := joinSession.GetSessionID() - terminalSize, err := client.GetRemoteTerminalSize(ctx, sessionID) + + if sessionParams != nil && sessionParams.JoinSessionID != "" { + // if we're joining an existing session, we need to assume that session's + // existing/current terminal size: + terminalSize, err := client.GetRemoteTerminalSize(ctx, sessionParams.JoinSessionID) if err != nil { return nil, trace.Wrap(err) } @@ -143,10 +158,10 @@ func newSession(ctx context.Context, if err != nil { log.ErrorContext(ctx, "Failed to resize terminal", "error", err) } - } - ns.env[sshutils.SessionEnvVar] = sessionID + // TODO(Joerger): DELETE IN v20.0.0 - session env var is no longer used for session joining. + ns.env[sshutils.SessionEnvVar] = sessionParams.JoinSessionID } // Close the Terminal when finished. @@ -171,7 +186,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func(s *tracessh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, sessionParams *tracessh.SessionParams, sessionCallback func(s *tracessh.Session) error) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/regularSession", @@ -179,7 +194,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func( ) defer span.End() - session, err := ns.createServerSession(ctx) + session, err := ns.createServerSession(ctx, nil) if err != nil { return trace.Wrap(err) } @@ -191,7 +206,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func( type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context, sessionParams *tracessh.SessionParams) (*tracessh.Session, error) { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/createServerSession", @@ -199,7 +214,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi ) defer span.End() - sess, err := ns.nodeClient.Client.NewSession(ctx) + sess, err := ns.nodeClient.Client.NewSessionWithParams(ctx, sessionParams) if err != nil { return nil, trace.Wrap(err) } @@ -267,7 +282,7 @@ func selectKeyAgent(ctx context.Context, tc *TeleportClient) sshagent.ClientGett // interactiveSession creates an interactive session on the remote node, executes // the given callback on it, and waits for the session to end -func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, sessionCallback interactiveCallback) error { +func (ns *NodeSession) interactiveSession(ctx context.Context, sessionParams *tracessh.SessionParams, sessionCallback interactiveCallback) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/interactiveSession", @@ -281,7 +296,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio termType = teleport.SafeTerminalType } // create the server-side session: - sess, err := ns.createServerSession(ctx) + sess, err := ns.createServerSession(ctx, sessionParams) if err != nil { return trace.Wrap(err) } @@ -307,6 +322,11 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio ns.watchSignals(remoteTerm) } + mode := types.SessionPeerMode + if sessionParams != nil && sessionParams.JoinMode != "" { + mode = sessionParams.JoinMode + } + // start piping input into the remote shell and pipe the output from // the remote shell into stdout: ns.pipeInOut(ctx, remoteTerm, mode, sess) @@ -511,8 +531,8 @@ func (s *sessionWriter) Write(p []byte) (int, error) { } // runShell executes user's shell on the remote node under an interactive session -func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error { +func (ns *NodeSession) runShell(ctx context.Context, sessionParams *tracessh.SessionParams, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { + return ns.interactiveSession(ctx, sessionParams, func(s *tracessh.Session, shell io.ReadWriteCloser) error { w := &sessionWriter{ tshOut: ns.nodeClient.TC.Stdout, session: s, @@ -540,7 +560,7 @@ func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipa // runCommand executes a "exec" request either in interactive mode (with a // TTY attached) or non-intractive mode (no TTY). -func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, shellCallback ShellCreatedCallback, interactive bool) error { +func (ns *NodeSession) runCommand(ctx context.Context, sessionParams *tracessh.SessionParams, cmd []string, shellCallback ShellCreatedCallback, interactive bool) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/runCommand", @@ -554,7 +574,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // keyboard based signals will be propogated to the TTY on the server which is // where all signal handling will occur. if interactive { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error { + return ns.interactiveSession(ctx, sessionParams, func(s *tracessh.Session, term io.ReadWriteCloser) error { err := s.Start(ctx, strings.Join(cmd, " ")) if err != nil { return trace.Wrap(err) @@ -581,7 +601,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // Unfortunately at the moment the Go SSH library Teleport uses does not // support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and // SSH client, and try and exit as gracefully as possible. - return ns.regularSession(ctx, func(s *tracessh.Session) error { + return ns.regularSession(ctx, sessionParams, func(s *tracessh.Session) error { errCh := make(chan error, 1) go func() { errCh <- s.Run(ctx, strings.Join(cmd, " ")) diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 9dffe572b6fab..c4aa5c67c3b34 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -471,13 +471,10 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net DataDir: s.srv.Config.DataDir, Address: params.Address, UseTunnel: useTunnel, - HostUUID: s.srv.ID, + ProxyUUID: s.srv.ID, Emitter: s.srv.Config.Emitter, ParentContext: s.srv.Context, LockWatcher: s.srv.LockWatcher, - TargetID: params.ServerID, - TargetAddr: params.To.String(), - TargetHostname: params.Address, TargetServer: params.TargetServer, Clock: s.clock, } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 2f6b211cb36dc..75bf261e1d377 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -913,13 +913,10 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne Address: params.Address, UseTunnel: UseTunnel(s.logger, targetConn), FIPS: s.srv.FIPS, - HostUUID: s.srv.ID, + ProxyUUID: s.srv.ID, Emitter: s.srv.Config.Emitter, ParentContext: s.srv.Context, LockWatcher: s.srv.LockWatcher, - TargetID: params.ServerID, - TargetAddr: params.To.String(), - TargetHostname: params.Address, TargetServer: params.TargetServer, Clock: s.clock, } diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index fd05ed2931f70..44ad1571c43d5 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -701,7 +701,7 @@ func (h *AuthHandlers) hostKeyCallback(hostname string, remote net.Addr, key ssh ctx := h.c.Server.Context() // For SubKindOpenSSHEICENode we use SSH Keys (EC2 does not support Certificates in ec2.SendSSHPublicKey). - if h.c.Server.TargetMetadata().ServerSubKind == types.SubKindOpenSSHEICENode { + if h.c.Server.GetInfo().GetSubKind() == types.SubKindOpenSSHEICENode { return nil } diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 0324ac96833c0..d43346e097a88 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -20,6 +20,7 @@ package srv import ( "context" + "encoding/json" "fmt" "io" "log/slog" @@ -51,8 +52,10 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/envutils" @@ -155,6 +158,7 @@ type Server interface { GetClock() clockwork.Clock // GetInfo returns a services.Server that represents this server. + // In the case of the Proxy forwarder, this is the node target. GetInfo() types.Server // UseTunnel used to determine if this node has connected to this cluster @@ -189,8 +193,8 @@ type Server interface { // support or not. GetSELinuxEnabled() bool - // TargetMetadata returns metadata about the session target node. - TargetMetadata() apievents.ServerMetadata + // EventMetadata returns [events.ServerMetadata] for this server. + EventMetadata() apievents.ServerMetadata } // IdentityContext holds all identity information associated with the user @@ -313,6 +317,15 @@ type ServerContext struct { // term holds PTY if it was requested by the session. term Terminal + // sessionParams are parameters associated with this server session. + sessionParams *tracessh.SessionParams + + // newSessionID is set if this server context is going to create a new session. + // This field must be set through [ServerContext.SetNewSessionID] for non-join + // sessions as soon as a session channel is accepted in order to inform + // the client of the to-be session ID. + newSessionID rsession.ID + // session holds the active session (if there's an active one). session *session @@ -435,7 +448,7 @@ type ServerContext struct { // the ServerContext is closed. The ctx parameter should be a child of the ctx // associated with the scope of the parent ConnectionContext to ensure that // cancellation of the ConnectionContext propagates to the ServerContext. -func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) { +func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, sessionParams *tracessh.SessionParams, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) { recConfig, err := srv.GetAccessPoint().GetSessionRecordingConfig(ctx) if err != nil { return nil, trace.Wrap(err) @@ -475,6 +488,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s id: int(atomic.AddInt32(&ctxID, int32(1))), env: make(map[string]string), srv: srv, + sessionParams: sessionParams, ExecResultCh: make(chan ExecResult, 10), SubsystemResultCh: make(chan SubsystemResult, 10), ClusterName: parent.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName], @@ -483,7 +497,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s clientIdleTimeout: clientIdleTimeout, cancelContext: cancelContext, cancel: cancel, - ServerSubKind: srv.TargetMetadata().ServerSubKind, + ServerSubKind: srv.GetInfo().GetSubKind(), } child.Logger = slog.With( @@ -586,31 +600,6 @@ func (c *ServerContext) ID() int { return c.id } -// GetJoinParams gets join params if they are set. -// -// These params (env vars) are set synchronously between the "session" channel request -// and the "shell" / "exec" channel request. Therefore, these params are only guaranteed -// to be accurately set during and after the "shell" / "exec" channel request. -// -// TODO(Joerger): Rather than relying on the out-of-band env var params, we should -// provide session params upfront as extra data in the session channel request. -func (c *ServerContext) GetJoinParams() (string, types.SessionParticipantMode) { - c.mu.RLock() - defer c.mu.RUnlock() - - sid, found := c.getEnvLocked(sshutils.SessionEnvVar) - if !found { - return "", "" - } - - mode := types.SessionPeerMode // default - if modeString, found := c.getEnvLocked(teleport.EnvSSHJoinMode); found { - mode = types.SessionParticipantMode(modeString) - } - - return sid, mode -} - // SessionID returns the ID of the session in the context. // // This value is not set until during and after the "shell" / "exec" channel request. @@ -679,36 +668,56 @@ func (c *ServerContext) SetEnv(key, val string) { c.mu.Unlock() } -// GetEnv returns a environment variable within this context. -func (c *ServerContext) GetEnv(key string) (string, bool) { +// GetSessionParams gets session params for the current session. +func (c *ServerContext) GetSessionParams() tracessh.SessionParams { c.mu.RLock() defer c.mu.RUnlock() - return c.getEnvLocked(key) -} -func (c *ServerContext) getEnvLocked(key string) (string, bool) { - val, ok := c.env[key] - if ok { - return val, true + // Teleport ssh clients should provide session params upfront in the session channel request. + if c.sessionParams != nil { + return *c.sessionParams + } + + // If this is an old client, it will provide session params from + // env variables sometime between the session channel request and shell request. + // TODO(Joerger): DELETE IN v20.0.0 - just return empty params for an old Teleport client / openSSH client session. + sessionParams := tracessh.SessionParams{ + WebProxyAddr: c.env[teleport.SSHSessionWebProxyAddr], + Reason: c.env[teleport.EnvSSHSessionReason], + DisplayParticipantRequirements: utils.AsBool(c.env[teleport.EnvSSHSessionDisplayParticipantRequirements]), + JoinSessionID: c.env[sshutils.SessionEnvVar], + JoinMode: types.SessionParticipantMode(c.env[teleport.EnvSSHJoinMode]), + ModeratedSessionID: c.env[sftp.EnvModeratedSessionID], + } + + if invitedUsers := c.env[teleport.EnvSSHSessionInvited]; invitedUsers != "" { + if err := json.Unmarshal([]byte(invitedUsers), &sessionParams.Invited); err != nil { + slog.WarnContext(context.Background(), "Failed to parse invited users", "error", err) + } } - return c.Parent().GetEnv(key) + + return sessionParams +} + +// SetNewSessionID sets the ID for a new session in this server context. +func (c *ServerContext) SetNewSessionID(ctx context.Context, sid rsession.ID) { + c.mu.Lock() + defer c.mu.Unlock() + c.newSessionID = sid +} + +// GetNewSessionID gets the ID for a new session in this server context. +func (c *ServerContext) GetNewSessionID() rsession.ID { + c.mu.Lock() + defer c.mu.Unlock() + return c.newSessionID } // setSession sets the context's session -func (c *ServerContext) setSession(ctx context.Context, sess *session, ch ssh.Channel) { +func (c *ServerContext) setSession(ctx context.Context, sess *session) { c.mu.Lock() defer c.mu.Unlock() c.session = sess - - // inform the client of the session ID that is being used in a new - // goroutine to reduce latency - go func() { - c.Logger.DebugContext(ctx, "Sending current session ID") - _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sess.ID())) - if err != nil { - c.Logger.DebugContext(ctx, "Failed to send the current session ID", "error", err) - } - }() } // getSession returns the context's session @@ -903,7 +912,7 @@ func (c *ServerContext) reportStats(conn utils.Stater) { Type: events.SessionDataEvent, Code: events.SessionDataCode, }, - ServerMetadata: c.srv.TargetMetadata(), + ServerMetadata: c.srv.EventMetadata(), SessionMetadata: c.GetSessionMetadata(), UserMetadata: c.Identity.GetUserMetadata(), ConnectionMetadata: apievents.ConnectionMetadata{ @@ -924,6 +933,15 @@ func (c *ServerContext) reportStats(conn utils.Stater) { serverRX.Add(float64(rxBytes)) } +// ShouldHandleRecording returns whether this server context is responsible for +// recording session events, including session recording, audit events, and session tracking. +func (c *ServerContext) ShouldHandleSessionRecording() bool { + // The only time this server is not responsible for recording the session is when this + // is a Teleport Node with Proxy recording mode turned on, where the forwarding node will + // handle the recording. + return c.srv.Component() != teleport.ComponentNode || !services.IsRecordAtProxy(c.SessionRecordingConfig.GetMode()) +} + func (c *ServerContext) Close() error { // If the underlying connection is holding tracking information, report that // to the audit log at close. @@ -1165,11 +1183,15 @@ func buildEnvironment(ctx *ServerContext) []string { } // Set some Teleport specific environment variables: SSH_TELEPORT_USER, - // SSH_TELEPORT_HOST_UUID, and SSH_TELEPORT_CLUSTER_NAME. + // SSH_TELEPORT_HOST_UUID, SSH_TELEPORT_CLUSTER_NAME, and SSH_SESSION_WEBPROXY_ADDR. env.AddTrusted(teleport.SSHTeleportHostUUID, ctx.srv.ID()) env.AddTrusted(teleport.SSHTeleportClusterName, ctx.ClusterName) env.AddTrusted(teleport.SSHTeleportUser, ctx.Identity.TeleportUser) + if ctx.GetSessionParams().WebProxyAddr != "" { + env.AddTrusted(teleport.SSHSessionWebProxyAddr, ctx.GetSessionParams().WebProxyAddr) + } + // At the end gather all dynamically defined environment variables ctx.VisitEnv(env.AddUnique) diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 2b65d301a0cc9..2fab90d5a1a1c 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -39,10 +39,8 @@ import ( "github.com/gravitational/teleport" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" - "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -102,10 +100,8 @@ func NewExecRequest(ctx *ServerContext, command string) (Exec, error) { }, nil } - // If this is a registered OpenSSH node or proxy recoding mode is - // enabled, execute the command on a remote host. This is used by - // in-memory forwarding nodes. - if types.IsOpenSSHNodeSubKind(ctx.ServerSubKind) || services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) { + // If this is a forwarding node, execute the command on a remote host. + if ctx.srv.Component() == teleport.ComponentForwardingNode { return &remoteExec{ ctx: ctx, command: command, @@ -264,7 +260,7 @@ func (e *localExec) transformSecureCopy() error { Time: time.Now(), }, UserMetadata: e.Ctx.Identity.GetUserMetadata(), - ServerMetadata: e.Ctx.GetServer().TargetMetadata(), + ServerMetadata: e.Ctx.GetServer().EventMetadata(), Error: err.Error(), }) return trace.Wrap(err) @@ -369,7 +365,7 @@ func (e *remoteExec) Start(ctx context.Context, ch ssh.Channel) (*ExecResult, er Time: time.Now(), }, UserMetadata: e.ctx.Identity.GetUserMetadata(), - ServerMetadata: e.ctx.GetServer().TargetMetadata(), + ServerMetadata: e.ctx.GetServer().EventMetadata(), Error: err.Error(), }) return nil, trace.Wrap(err) @@ -435,7 +431,7 @@ func (e *remoteExec) PID() int { // instead of ctx.srv. func emitExecAuditEvent(ctx *ServerContext, cmd string, execErr error) { // Create common fields for event. - serverMeta := ctx.GetServer().TargetMetadata() + serverMeta := ctx.GetServer().EventMetadata() sessionMeta := ctx.GetSessionMetadata() userMeta := ctx.Identity.GetUserMetadata() diff --git a/lib/srv/exec_test.go b/lib/srv/exec_test.go index f79242275d1ca..d5b38453e38ef 100644 --- a/lib/srv/exec_test.go +++ b/lib/srv/exec_test.go @@ -64,8 +64,6 @@ func TestEmitExecAuditEvent(t *testing.T) { rec, ok := scx.session.recorder.(*mockRecorder) require.True(t, ok) - scx.GetServer().TargetMetadata() - expectedUsr, err := user.Current() require.NoError(t, err) expectedHostname := "testHost" @@ -151,6 +149,7 @@ func newExecServerContext(t *testing.T, srv Server) *ServerContext { term: term, emitter: rec, recorder: rec, + scx: scx, } err = scx.SetSSHRequest(&ssh.Request{Type: sshutils.ExecRequest}) require.NoError(t, err) diff --git a/lib/srv/forward/sftp.go b/lib/srv/forward/sftp.go index f158ee9f1f853..9e9a60a60bd79 100644 --- a/lib/srv/forward/sftp.go +++ b/lib/srv/forward/sftp.go @@ -94,7 +94,7 @@ func (p *SFTPProxy) Serve() error { Code: events.SFTPSummaryCode, Time: time.Now(), }, - ServerMetadata: scx.GetServer().TargetMetadata(), + ServerMetadata: scx.GetServer().EventMetadata(), SessionMetadata: scx.GetSessionMetadata(), UserMetadata: scx.Identity.GetUserMetadata(), ConnectionMetadata: apievents.ConnectionMetadata{ @@ -230,7 +230,7 @@ func (h *proxyHandlers) sendSFTPEvent(req *sftp.Request, reqErr error) { } else if reqErr != nil { h.logger.DebugContext(req.Context(), "failed handling SFTP request", "request", req.Method, "error", reqErr) } - event.ServerMetadata = h.scx.GetServer().TargetMetadata() + event.ServerMetadata = h.scx.GetServer().EventMetadata() event.SessionMetadata = h.scx.GetSessionMetadata() event.UserMetadata = h.scx.Identity.GetUserMetadata() event.ConnectionMetadata = apievents.ConnectionMetadata{ diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index b4bffc04c99aa..1c0f35fa6c849 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -28,9 +28,9 @@ import ( "net" "os" "strings" + "sync" "time" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" @@ -53,6 +53,7 @@ import ( "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshagent" "github.com/gravitational/teleport/lib/sshutils" @@ -83,8 +84,6 @@ import ( type Server struct { logger *slog.Logger - id string - // targetConn is the TCP connection to the remote host. targetConn net.Conn @@ -155,9 +154,9 @@ type Server struct { clock clockwork.Clock - // hostUUID is the UUID of the underlying proxy that the forwarding server + // proxyUUID is the UUID of the underlying proxy that the forwarding server // is running in. - hostUUID string + proxyUUID string // closeContext and closeCancel are used to signal to the outside // world that this server is closed @@ -174,9 +173,6 @@ type Server struct { // of starting spans. tracerProvider oteltrace.TracerProvider - // TODO(Joerger): Remove in favor of targetServer, which has more accurate values. - targetID, targetAddr, targetHostname string - // targetServer is the host that the connection is being established for. targetServer types.Server } @@ -229,9 +225,9 @@ type ServerConfig struct { // configuration. FIPS bool - // HostUUID is the UUID of the underlying proxy that the forwarding server + // ProxyUUID is the UUID of the underlying proxy that the forwarding server // is running in. - HostUUID string + ProxyUUID string // Emitter is audit events emitter Emitter events.StreamEmitter @@ -247,9 +243,6 @@ type ServerConfig struct { // of starting spans. TracerProvider oteltrace.TracerProvider - // TODO(Joerger): Remove in favor of TargetServer, which has more accurate values. - TargetID, TargetAddr, TargetHostname string - // TargetServer is the host that the connection is being established for. TargetServer types.Server } @@ -331,7 +324,6 @@ func New(c ServerConfig) (*Server, error) { "src_addr", c.SrcAddr.String(), "dst_addr", c.DstAddr.String(), ), - id: uuid.New().String(), targetConn: c.TargetConn, serverConn: utils.NewTrackingConn(serverConn), clientConn: clientConn, @@ -344,14 +336,11 @@ func New(c ServerConfig) (*Server, error) { authService: c.LocalAuthClient, dataDir: c.DataDir, clock: c.Clock, - hostUUID: c.HostUUID, + proxyUUID: c.ProxyUUID, StreamEmitter: c.Emitter, parentContext: c.ParentContext, lockWatcher: c.LockWatcher, tracerProvider: c.TracerProvider, - targetID: c.TargetID, - targetAddr: c.TargetAddr, - targetHostname: c.TargetHostname, targetServer: c.TargetServer, } @@ -397,16 +386,18 @@ func New(c ServerConfig) (*Server, error) { return s, nil } -// TargetMetadata returns metadata about the forwarding target. -func (s *Server) TargetMetadata() apievents.ServerMetadata { +// EventMetadata returns metadata about the forwarding target. +func (s *Server) EventMetadata() apievents.ServerMetadata { + serverInfo := s.GetInfo() return apievents.ServerMetadata{ ServerVersion: teleport.Version, - ServerNamespace: s.GetNamespace(), - ServerID: s.targetID, - ServerAddr: s.targetAddr, - ServerHostname: s.targetHostname, - ForwardedBy: s.hostUUID, - ServerSubKind: s.targetServer.GetSubKind(), + ServerNamespace: serverInfo.GetNamespace(), + ServerID: serverInfo.GetName(), + ServerAddr: serverInfo.GetAddr(), + ServerLabels: serverInfo.GetAllLabels(), + ServerHostname: serverInfo.GetHostname(), + ServerSubKind: serverInfo.GetSubKind(), + ForwardedBy: s.proxyUUID, } } @@ -421,15 +412,15 @@ func (s *Server) GetDataDir() string { return s.dataDir } -// ID returns the ID of the proxy that creates the in-memory forwarding server. +// ID returns the UUID of the server targeted by the forwarding server. func (s *Server) ID() string { - return s.id + return s.targetServer.GetName() } // HostUUID is the UUID of the underlying proxy that the forwarding server // is running in. func (s *Server) HostUUID() string { - return s.hostUUID + return s.proxyUUID } // GetNamespace returns the namespace the forwarding server resides in. @@ -502,19 +493,35 @@ func (s *Server) GetSELinuxEnabled() bool { return false } -// GetInfo returns a services.Server that represents this server. +// GetInfo returns a services.Server that represents the target server. func (s *Server) GetInfo() types.Server { - return &types.ServerV2{ + // Only set the address for non-tunnel nodes. + var addr string + if !s.targetServer.GetUseTunnel() { + addr = s.targetServer.GetAddr() + } + + srv := &types.ServerV2{ Kind: types.KindNode, + SubKind: s.targetServer.GetSubKind(), Version: types.V2, Metadata: types.Metadata{ - Name: s.ID(), - Namespace: s.GetNamespace(), + Name: s.targetServer.GetName(), + Namespace: s.targetServer.GetNamespace(), + Labels: s.targetServer.GetLabels(), }, Spec: types.ServerSpecV2{ - Addr: s.AdvertiseAddr(), + CmdLabels: types.LabelsToV2(s.targetServer.GetCmdLabels()), + Addr: addr, + Hostname: s.targetServer.GetHostname(), + UseTunnel: s.useTunnel, + Version: teleport.Version, + ProxyIDs: s.targetServer.GetProxyIDs(), + PublicAddrs: s.targetServer.GetPublicAddrs(), }, } + + return srv } // Dial returns the client connection created by pipeAddrConn. @@ -557,6 +564,10 @@ func (s *Server) Serve() { config.KeyExchanges = s.kexAlgorithms config.MACs = s.macAlgorithms + // Set the server version to Teleport to enable tracing and other Teleport + // specific features like joining. + config.ServerVersion = sshutils.SSHVersionPrefix + netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(s.Context()) if err != nil { s.logger.ErrorContext(s.Context(), "Unable to fetch cluster config", "error", err) @@ -907,7 +918,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha // Create context for this channel. This context will be closed when // forwarding is complete. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, nil) if err != nil { if err := nch.Reject(ssh.ConnectionFailed, "failed to open server context"); err != nil { s.logger.ErrorContext(ctx, "Error rejecting forwarded-tcpip channel", "error", err) @@ -978,8 +989,23 @@ func (s *Server) handleGlobalRequest(ctx context.Context, req *ssh.Request) { } // Pass request on unchanged. case teleport.SessionIDQueryRequest: + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. + // Reply true to session ID query requests, we will set new - // session IDs for new sessions + // session IDs for new sessions during the shel/exec channel + // request. + if err := req.Reply(true, nil); err != nil { + s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) + } + return + case teleport.SessionIDQueryRequestV2: + // TODO(Joerger): DELETE IN v21.0.0 + // clients should stop checking in v21, and servers should stop responding to the query in v22. + + // Reply true to session ID query requests, we will set new + // session IDs for new sessions directly after accepting the + // session channel request. if err := req.Reply(true, nil); err != nil { s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) } @@ -1017,7 +1043,7 @@ func (s *Server) checkTCPIPForwardRequest(ctx context.Context, r *ssh.Request) e // RBAC checks are only necessary when connecting to an agentless node if s.targetServer.IsOpenSSHNode() { - scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext, nil) if err != nil { return err } @@ -1080,7 +1106,7 @@ func (s *Server) handleChannel(ctx context.Context, nch ssh.NewChannel) { func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, req *sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when // forwarding is complete. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, nil) if err != nil { s.logger.ErrorContext(ctx, "Unable to create connection context", "error", err) s.stderrWrite(ctx, ch, "Unable to create connection context.") @@ -1132,12 +1158,22 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r // the remote host. Once the session channel has been established, this function's loop handles // all the "exec", "subsystem" and "shell" requests. func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { + // sessionParams will not be passed by old clients (< v19) or OpenSSH clients. + sessionParams, err := tracessh.ParseSessionParams(nch.ExtraData()) + if err != nil { + s.logger.ErrorContext(ctx, "Failed to parse request data", "data", string(nch.ExtraData()), "error", err) + if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + // Create context for this channel. This context will be closed when the // session request is complete. // There is no need for the forwarding server to initiate disconnects, // based on teleport business logic, because this logic is already // done on the server's terminating side. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, sessionParams) if err != nil { s.logger.WarnContext(ctx, "Server context setup failed", "error", err) if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("server context setup failed: %v", err)); err != nil { @@ -1154,11 +1190,52 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { scx.SetAllowFileCopying(true) defer scx.Close() + // If this is a Teleport node server, it should send the session ID + // right after the session channel is accepted. We should reuse this + // session ID and delegate session responsibilities (recordings, audit + // events, and session trackers) to avoid duplicates. + // + // Register handler to receive the current session ID before starting the session. + var newSessionIDFromServer chan string + if s.targetServer.GetSubKind() == types.SubKindTeleportNode { + // Check if the Teleport Node is outdated and won't actually send the session ID. + // + // TODO(Joerger): DELETE IN v20.0.0 + // all v19+ servers set and share the session ID directly after accepting the session channel. + // clients should stop checking in v21, and servers should stop responding to the query in v22. + reply, payload, err := s.remoteClient.SendRequest(ctx, teleport.SessionIDQueryRequestV2, true, nil) + if err != nil { + s.logger.WarnContext(ctx, "Failed to send session ID query request", "error", err) + } else if !reply && payload != nil { + // If the target node replies with a payload, this means that the connection itself has been rejected, + // presumably due to an authz error, and the server is trying to communicate the error with the first + // req/chan received. + s.logger.WarnContext(ctx, "Remote session open failed", "error", err) + if err := nch.Reject(ssh.Prohibited, fmt.Sprintf("remote session open failed: %v", string(payload))); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + + if err == nil && reply { + newSessionIDFromServer = make(chan string, 1) + var receiveSessionIDOnce sync.Once + s.remoteClient.HandleSessionRequest(ctx, teleport.CurrentSessionIDRequest, func(ctx context.Context, req *ssh.Request) { + // Only handle the first request - only one is expected. + receiveSessionIDOnce.Do(func() { + newSessionIDFromServer <- string(req.Payload) + }) + }) + } else { + s.logger.WarnContext(ctx, "Failed to query session ID from target node. Ensure the targeted Teleport Node is upgraded to v19.0.0+ to avoid duplicate events due to mismatched session IDs.") + } + } + // Create a "session" channel on the remote host. Note that we // create the remote session channel before accepting the local // channel request; this allows us to propagate the rejection // reason/message in the event the channel is rejected. - remoteSession, err := s.remoteClient.NewSession(ctx) + remoteSession, err := s.remoteClient.NewSessionWithParams(ctx, sessionParams) if err != nil { s.logger.WarnContext(ctx, "Remote session open failed", "error", err) reason, msg := ssh.ConnectionFailed, fmt.Sprintf("remote session open failed: %v", err) @@ -1173,6 +1250,38 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { } scx.RemoteSession = remoteSession + if newSessionIDFromServer != nil { + // Wait for the session ID to be reported by the target node. + select { + case sidString := <-newSessionIDFromServer: + sid, err := session.ParseID(sidString) + if err != nil { + s.logger.WarnContext(ctx, "Unable to parse session ID reported by target Teleport Node", "error", err) + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + scx.SetNewSessionID(ctx, *sid) + case <-time.After(10 * time.Second): + s.logger.WarnContext(ctx, "Failed to receive session ID from target node. Ensure the targeted Teleport Node is upgraded to v19.0.0+ to avoid duplicate events due to mismatched session IDs.") + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + case <-ctx.Done(): + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + } else { + // The target node is not expected to report session ID, either because it's + // outdated or an agentless node. Continue with a random session ID and ensure + // we create a new session tracker. + scx.SetNewSessionID(ctx, session.NewID()) + } + // Accept the session channel request ch, in, err := nch.Accept() if err != nil { @@ -1183,9 +1292,19 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { return } scx.AddCloser(ch) - ch = scx.TrackActivity(ch) + // inform the client of the session ID that is going to be used in a new + // goroutine to reduce latency. + go func() { + sid := scx.GetNewSessionID() + s.logger.DebugContext(ctx, "Sending current session ID", "sid", sid) + _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sid)) + if err != nil { + s.logger.DebugContext(ctx, "Failed to send the current session ID", "error", err) + } + }() + s.logger.DebugContext(ctx, "Opening session request", "target_addr", s.sconn.RemoteAddr(), "session_id", scx.ID()) defer s.logger.DebugContext(ctx, "Closing session request", "target_addr", s.sconn.RemoteAddr(), "session_id", scx.ID()) @@ -1470,7 +1589,7 @@ func (s *Server) handleSubsystem(ctx context.Context, ch ssh.Channel, req *ssh.R Time: time.Now(), }, UserMetadata: serverContext.Identity.GetUserMetadata(), - ServerMetadata: serverContext.GetServer().TargetMetadata(), + ServerMetadata: serverContext.GetServer().EventMetadata(), Error: err.Error(), }) return trace.Wrap(err) diff --git a/lib/srv/forward/sshserver_test.go b/lib/srv/forward/sshserver_test.go index e46201010d898..cac12932e1668 100644 --- a/lib/srv/forward/sshserver_test.go +++ b/lib/srv/forward/sshserver_test.go @@ -28,12 +28,15 @@ import ( "sync/atomic" "testing" + "github.com/google/uuid" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/authclient" @@ -380,3 +383,103 @@ func TestServerConfigCheckDefaults(t *testing.T) { }) } } + +func TestEventMetadata(t *testing.T) { + nodeID := uuid.NewString() + proxyID := uuid.NewString() + + for _, tt := range []struct { + name string + subkind string + spec types.ServerSpecV2 + labels map[string]string + expectMetadata events.ServerMetadata + }{ + { + name: "tunnel node", + labels: map[string]string{ + "stcLabel": "stcResult", + }, + spec: types.ServerSpecV2{ + Addr: "127.0.0.1:3022", + CmdLabels: map[string]types.CommandLabelV2{ + "cmdLabel": {Result: "cmdResult"}, + }, + Hostname: "server01", + UseTunnel: true, + }, + expectMetadata: events.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: nodeID, + ServerNamespace: apidefaults.Namespace, + ServerAddr: "", + ServerHostname: "server01", + ServerLabels: map[string]string{ + "stcLabel": "stcResult", + "cmdLabel": "cmdResult", + }, + ServerSubKind: types.SubKindTeleportNode, + ForwardedBy: proxyID, + }, + }, { + name: "tunnel node", + labels: map[string]string{ + "stcLabel": "stcResult", + }, + spec: types.ServerSpecV2{ + Addr: "127.0.0.1:3022", + CmdLabels: map[string]types.CommandLabelV2{ + "cmdLabel": {Result: "cmdResult"}, + }, + Hostname: "server01", + }, + expectMetadata: events.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: nodeID, + ServerNamespace: apidefaults.Namespace, + ServerAddr: "127.0.0.1:3022", + ServerHostname: "server01", + ServerLabels: map[string]string{ + "stcLabel": "stcResult", + "cmdLabel": "cmdResult", + }, + ServerSubKind: types.SubKindTeleportNode, + ForwardedBy: proxyID, + }, + }, { + name: "agentless node", + subkind: types.SubKindOpenSSHNode, + labels: map[string]string{ + "stcLabel": "stcResult", + }, + spec: types.ServerSpecV2{ + Addr: "openssh.example.com:22", + Hostname: "agentless-host", + }, + expectMetadata: events.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: nodeID, + ServerNamespace: apidefaults.Namespace, + ServerAddr: "openssh.example.com:22", + ServerHostname: "agentless-host", + ServerLabels: map[string]string{ + "stcLabel": "stcResult", + }, + ServerSubKind: types.SubKindOpenSSHNode, + ForwardedBy: proxyID, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + targetServer, err := types.NewNode(nodeID, tt.subkind, tt.spec, tt.labels) + require.NoError(t, err) + + forwardSrv := &Server{ + proxyUUID: proxyID, + targetServer: targetServer, + } + + require.EqualValues(t, tt.expectMetadata, forwardSrv.EventMetadata()) + }) + } +} diff --git a/lib/srv/forward/subsystem.go b/lib/srv/forward/subsystem.go index fbcb4957183e4..0c76d3f2a2a25 100644 --- a/lib/srv/forward/subsystem.go +++ b/lib/srv/forward/subsystem.go @@ -164,7 +164,7 @@ func (r *remoteSubsystem) emitAuditEvent(ctx context.Context, err error) { RemoteAddr: r.serverContext.RemoteClient.RemoteAddr().String(), }, Name: r.subsystemName, - ServerMetadata: r.serverContext.GetServer().TargetMetadata(), + ServerMetadata: r.serverContext.GetServer().EventMetadata(), } if err != nil { diff --git a/lib/srv/git/forward.go b/lib/srv/git/forward.go index 1427e7173e52c..5383dba230fcb 100644 --- a/lib/srv/git/forward.go +++ b/lib/srv/git/forward.go @@ -374,7 +374,7 @@ func (s *ForwardServer) onConnection(ctx context.Context, ccx *sshutils.Connecti // TODO(greedy52) decouple from srv.NewServerContext. We only need // connection monitoring. - serverCtx, err := srv.NewServerContext(ctx, ccx, s, identityCtx) + serverCtx, err := srv.NewServerContext(ctx, ccx, s, identityCtx, nil) if err != nil { return nil, trace.Wrap(err) } @@ -400,11 +400,18 @@ func (s *ForwardServer) onChannel(ctx context.Context, ccx *sshutils.ConnectionC return } + // sessionParams will not be passed by old clients (< v19) or OpenSSH clients. + sessionParams, err := tracessh.ParseSessionParams(nch.ExtraData()) + if err != nil { + s.reply.RejectWithAcceptError(ctx, nch, err) + return + } + if s.remoteClient == nil { s.reply.RejectWithNewRemoteSessionError(ctx, nch, trace.NotFound("missing remote client")) return } - remoteSession, err := s.remoteClient.NewSession(ctx) + remoteSession, err := s.remoteClient.NewSessionWithParams(ctx, sessionParams) if err != nil { s.reply.RejectWithNewRemoteSessionError(ctx, nch, err) return @@ -566,7 +573,7 @@ func (s *ForwardServer) makeGitCommandEvent(sctx *sessionContext, command string RemoteAddr: sctx.ServerConn.RemoteAddr().String(), LocalAddr: sctx.ServerConn.LocalAddr().String(), }, - ServerMetadata: s.TargetMetadata(), + ServerMetadata: s.EventMetadata(), } if err != nil { event.Metadata.Code = events.GitCommandFailureCode @@ -663,7 +670,7 @@ func makeRemoteSigner(ctx context.Context, cfg *ForwardServerConfig, identityCtx func (s *ForwardServer) Context() context.Context { return s.cfg.ParentContext } -func (s *ForwardServer) TargetMetadata() apievents.ServerMetadata { +func (s *ForwardServer) EventMetadata() apievents.ServerMetadata { return apievents.ServerMetadata{ ServerVersion: teleport.Version, ServerNamespace: s.cfg.TargetServer.GetNamespace(), diff --git a/lib/srv/mock_test.go b/lib/srv/mock_test.go index 61546e42e783a..a4ea1fb6a98c5 100644 --- a/lib/srv/mock_test.go +++ b/lib/srv/mock_test.go @@ -47,6 +47,7 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -74,6 +75,7 @@ func newTestServerContext(t *testing.T, srv Server, sessionJoiningRoleSet servic clusterName := "localhost" _, connCtx := sshutils.NewConnectionContext(ctx, nil, &ssh.ServerConn{Conn: sshConn}) scx := &ServerContext{ + newSessionID: rsession.NewID(), Logger: logtest.NewLogger(), ConnectionContext: connCtx, env: make(map[string]string), @@ -93,6 +95,9 @@ func newTestServerContext(t *testing.T, srv Server, sessionJoiningRoleSet servic }, cancelContext: ctx, cancel: cancel, + // If proxy forwarding is being used (proxy recording, agentless), then remote session must be set. + // Otherwise, this field is ignored. + RemoteSession: mockSSHSession(t), } err = scx.SetExecRequest(&localExec{Ctx: scx}) @@ -161,6 +166,7 @@ func newMockServer(t *testing.T) *mockServer { datadir: t.TempDir(), MockRecorderEmitter: &eventstest.MockRecorderEmitter{}, clock: clock, + component: teleport.ComponentNode, } } @@ -253,7 +259,7 @@ func (m *mockServer) GetInfo() types.Server { } } -func (m *mockServer) TargetMetadata() apievents.ServerMetadata { +func (m *mockServer) EventMetadata() apievents.ServerMetadata { return apievents.ServerMetadata{ ServerID: "123", ForwardedBy: "abc", diff --git a/lib/srv/regular/sftp.go b/lib/srv/regular/sftp.go index b58efa842999d..c47278b382872 100644 --- a/lib/srv/regular/sftp.go +++ b/lib/srv/regular/sftp.go @@ -77,7 +77,7 @@ func (s *sftpSubsys) Start(ctx context.Context, Time: time.Now(), }, UserMetadata: serverCtx.Identity.GetUserMetadata(), - ServerMetadata: serverCtx.GetServer().TargetMetadata(), + ServerMetadata: serverCtx.GetServer().EventMetadata(), Error: srv.ErrNodeFileCopyingNotPermitted.Error(), }) return srv.ErrNodeFileCopyingNotPermitted @@ -168,7 +168,7 @@ func (s *sftpSubsys) Start(ctx context.Context, defer auditPipeOut.Close() // Create common fields for events - serverMeta := serverCtx.GetServer().TargetMetadata() + serverMeta := serverCtx.GetServer().EventMetadata() sessionMeta := serverCtx.GetSessionMetadata() userMeta := serverCtx.Identity.GetUserMetadata() connectionMeta := apievents.ConnectionMetadata{ diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 84910ab87c7df..b4a6c105d4c69 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -67,6 +67,7 @@ import ( authorizedkeysreporter "github.com/gravitational/teleport/lib/secretsscanner/authorizedkeys" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/sshagent" @@ -258,15 +259,17 @@ type Server struct { scope string } -// TargetMetadata returns metadata about the server. -func (s *Server) TargetMetadata() apievents.ServerMetadata { +// EventMetadata returns metadata about the server. +func (s *Server) EventMetadata() apievents.ServerMetadata { + serverInfo := s.GetInfo() return apievents.ServerMetadata{ ServerVersion: teleport.Version, - ServerNamespace: s.GetNamespace(), - ServerID: s.ID(), - ServerAddr: s.Addr(), - ServerLabels: s.getAllLabels(), - ServerHostname: s.hostname, + ServerNamespace: serverInfo.GetNamespace(), + ServerID: serverInfo.GetName(), + ServerAddr: serverInfo.GetAddr(), + ServerLabels: serverInfo.GetAllLabels(), + ServerHostname: serverInfo.GetHostname(), + ServerSubKind: serverInfo.GetSubKind(), } } @@ -1110,18 +1113,6 @@ func (s *Server) getDynamicLabels() map[string]types.CommandLabelV2 { return types.LabelsToV2(s.dynamicLabels.Get()) } -// getAllLabels return a combination of static and dynamic labels. -func (s *Server) getAllLabels() map[string]string { - lmap := make(map[string]string) - for key, value := range s.getStaticLabels() { - lmap[key] = value - } - for key, cmd := range s.getDynamicLabels() { - lmap[key] = cmd.Result - } - return lmap -} - // GetInfo returns a services.Server that represents this server. func (s *Server) GetInfo() types.Server { return s.getBasicInfo() @@ -1248,7 +1239,7 @@ func (s *Server) getNetworkingProcess(scx *srv.ServerContext) (*networking.Proce // the server connection is closed. func (s *Server) startNetworkingProcess(scx *srv.ServerContext) (*networking.Process, error) { // Create context for the networking process. - nsctx, err := srv.NewServerContext(context.Background(), scx.ConnectionContext, s, scx.Identity) + nsctx, err := srv.NewServerContext(context.Background(), scx.ConnectionContext, s, scx.Identity, nil) if err != nil { return nil, trace.Wrap(err) } @@ -1307,8 +1298,23 @@ func (s *Server) HandleRequest(ctx context.Context, ccx *sshutils.ConnectionCont } } case teleport.SessionIDQueryRequest: + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. + // Reply true to session ID query requests, we will set new - // session IDs for new sessions + // session IDs for new sessions during the shel/exec channel + // request. + if err := r.Reply(true, nil); err != nil { + s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) + } + return + case teleport.SessionIDQueryRequestV2: + // TODO(Joerger): DELETE IN v21.0.0 + // clients should stop checking in v21, and servers should stop responding to the query in v22. + + // Reply true to session ID query requests, we will set new + // session IDs for new sessions directly after accepting the + // session channel request. if err := r.Reply(true, nil); err != nil { s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) } @@ -1433,7 +1439,7 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont s.rejectChannel(ctx, nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleSessionRequests(ctx, ccx, identityContext, ch, requests) + go s.handleSessionRequests(ctx, ccx, identityContext, nil, ch, requests) return default: s.rejectChannel(ctx, nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) @@ -1476,6 +1482,15 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont } decr = d } + + // SessionParams are not passed by old clients (