diff --git a/constants.go b/constants.go index 47b4defc1846f..c52e95ad3df36 100644 --- a/constants.go +++ b/constants.go @@ -831,9 +831,22 @@ const ( CurrentSessionIDRequest = "current-session-id@goteleport.com" // SessionIDQueryRequest is sent by clients to ask servers if they - // will generate their own session ID when a new session is created. + // will generate and share their own session ID when a new session + // is started (session and exec/shell channels accepted). + // + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. SessionIDQueryRequest = "session-id-query@goteleport.com" + // SessionIDQueryRequestV2 is sent by clients to ask servers if they + // will generate and share their own session ID when a new session + // channel is accepted, rather than when the shell/exec channel is. + // + // TODO(Joerger): DELETE IN v21.0.0 + // all v19+ servers set the session ID directly after accepting the session channel. + // clients should stop checking in v21, and servers should stop responding to the query in v22. + SessionIDQueryRequestV2 = "session-id-query-v2@goteleport.com" + // ForceTerminateRequest is an SSH request to forcefully terminate a session. ForceTerminateRequest = "x-teleport-force-terminate" diff --git a/integration/integration_test.go b/integration/integration_test.go index 61cbd69254b18..7c5bc248ed1b6 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -531,6 +531,7 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } } + // Test streaming events and recording. capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID) findByType := func(et string) apievents.AuditEvent { @@ -541,19 +542,6 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } return nil } - // helper that asserts that a session event is also included in the - // general audit log. 
- requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) { - t.Helper() - auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ - To: time.Now(), - EventTypes: []string{sessionEvent.GetType()}, - }) - require.NoError(t, err) - require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { - return ae.GetID() == sessionEvent.GetID() - })) - } // there should always be 'session.start' event (and it must be first) first := sessionEvents[0].(*apievents.SessionStart) @@ -561,19 +549,16 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { require.Equal(t, first, start) require.Equal(t, sessionID, start.SessionID) require.NotEmpty(t, start.TerminalSize) - requireInAuditLog(t, start) // there should always be 'session.end' event end := findByType(events.SessionEndEvent).(*apievents.SessionEnd) require.NotNil(t, end) require.Equal(t, sessionID, end.SessionID) - requireInAuditLog(t, end) // there should always be 'session.leave' event leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave) require.NotNil(t, leave) require.Equal(t, sessionID, leave.SessionID) - requireInAuditLog(t, leave) // all of them should have a proper time for _, e := range sessionEvents { @@ -584,6 +569,31 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { recorded := replaceNewlines(capturedStream) require.Regexp(t, ".*exit.*", recorded) require.Regexp(t, ".*echo hi.*", recorded) + + sessionEvents, _, err = site.SearchEvents(ctx, events.SearchEventsRequest{ + From: time.Time{}, + To: time.Now(), + EventTypes: []string{ + events.SessionStartEvent, + events.SessionLeaveEvent, + events.SessionEndEvent, + }, + }) + require.NoError(t, err) + + // Check that the events found above in the session stream show up in the backend. 
+ require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == start.GetID() + }), "expected session events to contain session.start event") + require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == end.GetID() + }), "expected session events to contain session.end event") + require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == leave.GetID() + }), "expected session events to contain session.leave event") + + // Ensure there are no duplicate events, e.g. from proxy recording mode. + require.Len(t, sessionEvents, 3, "%d unexpected duplicate events", len(sessionEvents)-3) }) } } diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index ed85312232aef..6431fa246edba 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -52,6 +52,7 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/sftp" @@ -319,6 +320,12 @@ type ServerContext struct { // sessionParams are parameters associated with this server session. sessionParams *tracessh.SessionParams + // newSessionID is set if this server context is going to create a new session. + // This field must be set through [ServerContext.SetNewSessionID] for non-join + // sessions as soon as a session channel is accepted in order to inform + // the client of the to-be session ID. + newSessionID rsession.ID + // session holds the active session (if there's an active one). session *session @@ -692,21 +699,25 @@ func (c *ServerContext) GetSessionParams() tracessh.SessionParams { return sessionParams } +// SetNewSessionID sets the ID for a new session in this server context. 
+func (c *ServerContext) SetNewSessionID(ctx context.Context, sid rsession.ID) { + c.mu.Lock() + defer c.mu.Unlock() + c.newSessionID = sid +} + +// GetNewSessionID gets the ID for a new session in this server context. +func (c *ServerContext) GetNewSessionID() rsession.ID { + c.mu.Lock() + defer c.mu.Unlock() + return c.newSessionID +} + // setSession sets the context's session -func (c *ServerContext) setSession(ctx context.Context, sess *session, ch ssh.Channel) { +func (c *ServerContext) setSession(ctx context.Context, sess *session) { c.mu.Lock() defer c.mu.Unlock() c.session = sess - - // inform the client of the session ID that is being used in a new - // goroutine to reduce latency - go func() { - c.Logger.DebugContext(ctx, "Sending current session ID") - _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sess.ID())) - if err != nil { - c.Logger.DebugContext(ctx, "Failed to send the current session ID", "error", err) - } - }() } // getSession returns the context's session @@ -922,6 +933,15 @@ func (c *ServerContext) reportStats(conn *utils.TrackingConn) { serverRX.Add(float64(rxBytes)) } +// ShouldHandleSessionRecording returns whether this server context is responsible for +// recording session events, including session recording, audit events, and session tracking. +func (c *ServerContext) ShouldHandleSessionRecording() bool { + // The only time this server is not responsible for recording the session is when this + // is a Teleport Node with Proxy recording mode turned on, where the forwarding node will + // handle the recording. + return c.srv.Component() != teleport.ComponentNode || !services.IsRecordAtProxy(c.SessionRecordingConfig.GetMode()) +} + func (c *ServerContext) Close() error { // If the underlying connection is holding tracking information, report that // to the audit log at close. 
diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 8daf6c7c1c31f..2fab90d5a1a1c 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -39,10 +39,8 @@ import ( "github.com/gravitational/teleport" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" - "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -102,10 +100,8 @@ func NewExecRequest(ctx *ServerContext, command string) (Exec, error) { }, nil } - // If this is a registered OpenSSH node or proxy recoding mode is - // enabled, execute the command on a remote host. This is used by - // in-memory forwarding nodes. - if types.IsOpenSSHNodeSubKind(ctx.ServerSubKind) || services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) { + // If this is a forwarding node, execute the command on a remote host. + if ctx.srv.Component() == teleport.ComponentForwardingNode { return &remoteExec{ ctx: ctx, command: command, diff --git a/lib/srv/exec_test.go b/lib/srv/exec_test.go index 086a43e7985a2..d5b38453e38ef 100644 --- a/lib/srv/exec_test.go +++ b/lib/srv/exec_test.go @@ -149,6 +149,7 @@ func newExecServerContext(t *testing.T, srv Server) *ServerContext { term: term, emitter: rec, recorder: rec, + scx: scx, } err = scx.SetSSHRequest(&ssh.Request{Type: sshutils.ExecRequest}) require.NoError(t, err) diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 20da5ea0b1b40..14b965299be8b 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -28,6 +28,7 @@ import ( "net" "os" "strings" + "sync" "time" "github.com/gravitational/trace" @@ -51,6 +52,7 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + 
"github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshagent" "github.com/gravitational/teleport/lib/sshutils" @@ -967,8 +969,23 @@ func (s *Server) handleGlobalRequest(ctx context.Context, req *ssh.Request) { } // Pass request on unchanged. case teleport.SessionIDQueryRequest: + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. + // Reply true to session ID query requests, we will set new - // session IDs for new sessions + // session IDs for new sessions during the shell/exec channel + // request. + if err := req.Reply(true, nil); err != nil { + s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) + } + return + case teleport.SessionIDQueryRequestV2: + // TODO(Joerger): DELETE IN v21.0.0 + // clients should stop checking in v21, and servers should stop responding to the query in v22. + + // Reply true to session ID query requests, we will set new + // session IDs for new sessions directly after accepting the + // session channel request. if err := req.Reply(true, nil); err != nil { s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) } @@ -1153,6 +1170,47 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { scx.SetAllowFileCopying(true) defer scx.Close() + // If this is a Teleport node server, it should send the session ID + // right after the session channel is accepted. We should reuse this + // session ID and delegate session responsibilities (recordings, audit + // events, and session trackers) to avoid duplicates. + // + // Register handler to receive the current session ID before starting the session. + var newSessionIDFromServer chan string + if s.targetServer.GetSubKind() == types.SubKindTeleportNode { + // Check if the Teleport Node is outdated and won't actually send the session ID. 
+ // + // TODO(Joerger): DELETE IN v21.0.0 + // all v19+ servers set and share the session ID directly after accepting the session channel. + // clients should stop checking in v21, and servers should stop responding to the query in v22. + reply, payload, err := s.remoteClient.SendRequest(ctx, teleport.SessionIDQueryRequestV2, true, nil) + if err != nil { + s.logger.WarnContext(ctx, "Failed to send session ID query request", "error", err) + } else if !reply && payload != nil { + // If the target node replies with a payload, this means that the connection itself has been rejected, + // presumably due to an authz error, and the server is trying to communicate the error with the first + // req/chan received. + s.logger.WarnContext(ctx, "Remote session open failed", "error", string(payload)) + if err := nch.Reject(ssh.Prohibited, fmt.Sprintf("remote session open failed: %v", string(payload))); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + + if err == nil && reply { + newSessionIDFromServer = make(chan string, 1) + var receiveSessionIDOnce sync.Once + s.remoteClient.HandleSessionRequest(ctx, teleport.CurrentSessionIDRequest, func(ctx context.Context, req *ssh.Request) { + // Only handle the first request - only one is expected. + receiveSessionIDOnce.Do(func() { + newSessionIDFromServer <- string(req.Payload) + }) + }) + } else { + s.logger.WarnContext(ctx, "Failed to query session ID from target node. Ensure the targeted Teleport Node is upgraded to v19.0.0+ to avoid duplicate events due to mismatched session IDs.") + } + } + // Create a "session" channel on the remote host. 
Note that we // create the remote session channel before accepting the local // channel request; this allows us to propagate the rejection @@ -1172,6 +1230,38 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { } scx.RemoteSession = remoteSession + if newSessionIDFromServer != nil { + // Wait for the session ID to be reported by the target node. + select { + case sidString := <-newSessionIDFromServer: + sid, err := session.ParseID(sidString) + if err != nil { + s.logger.WarnContext(ctx, "Unable to parse session ID reported by target Teleport Node", "error", err) + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + scx.SetNewSessionID(ctx, *sid) + case <-time.After(10 * time.Second): + s.logger.WarnContext(ctx, "Failed to receive session ID from target node. Ensure the targeted Teleport Node is upgraded to v19.0.0+ to avoid duplicate events due to mismatched session IDs.") + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + case <-ctx.Done(): + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + } else { + // The target node is not expected to report session ID, either because it's + // outdated or an agentless node. Continue with a random session ID and ensure + // we create a new session tracker. 
+ scx.SetNewSessionID(ctx, session.NewID()) + } + // Accept the session channel request ch, in, err := nch.Accept() if err != nil { @@ -1182,9 +1272,19 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { return } scx.AddCloser(ch) - ch = scx.TrackActivity(ch) + // inform the client of the session ID that is going to be used in a new + // goroutine to reduce latency. + go func() { + sid := scx.GetNewSessionID() + s.logger.DebugContext(ctx, "Sending current session ID", "sid", sid) + _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sid)) + if err != nil { + s.logger.DebugContext(ctx, "Failed to send the current session ID", "error", err) + } + }() + s.logger.DebugContext(ctx, "Opening session request", "target_addr", s.sconn.RemoteAddr(), "session_id", scx.ID()) defer s.logger.DebugContext(ctx, "Closing session request", "target_addr", s.sconn.RemoteAddr(), "session_id", scx.ID()) diff --git a/lib/srv/mock_test.go b/lib/srv/mock_test.go index ab5e7c2f49594..5b48830edc04b 100644 --- a/lib/srv/mock_test.go +++ b/lib/srv/mock_test.go @@ -47,6 +47,7 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -74,6 +75,7 @@ func newTestServerContext(t *testing.T, srv Server, sessionJoiningRoleSet servic clusterName := "localhost" _, connCtx := sshutils.NewConnectionContext(ctx, nil, &ssh.ServerConn{Conn: sshConn}) scx := &ServerContext{ + newSessionID: rsession.NewID(), Logger: logtest.NewLogger(), ConnectionContext: connCtx, env: make(map[string]string), @@ -93,6 +95,9 @@ func newTestServerContext(t *testing.T, srv Server, sessionJoiningRoleSet servic }, cancelContext: ctx, cancel: cancel, + // If proxy 
forwarding is being used (proxy recording, agentless), then remote session must be set. + // Otherwise, this field is ignored. + RemoteSession: mockSSHSession(t), } err = scx.SetExecRequest(&localExec{Ctx: scx}) @@ -160,6 +165,7 @@ func newMockServer(t *testing.T) *mockServer { datadir: t.TempDir(), MockRecorderEmitter: &eventstest.MockRecorderEmitter{}, clock: clock, + component: teleport.ComponentNode, } } diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 106297cd48a32..7b48bd44a8ee1 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -66,6 +66,7 @@ import ( authorizedkeysreporter "github.com/gravitational/teleport/lib/secretsscanner/authorizedkeys" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/sshagent" @@ -1296,8 +1297,23 @@ func (s *Server) HandleRequest(ctx context.Context, ccx *sshutils.ConnectionCont } } case teleport.SessionIDQueryRequest: + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. + + // Reply true to session ID query requests, we will set new + // session IDs for new sessions during the shell/exec channel + // request. + if err := r.Reply(true, nil); err != nil { + s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) + } + return + case teleport.SessionIDQueryRequestV2: + // TODO(Joerger): DELETE IN v21.0.0 + // clients should stop checking in v21, and servers should stop responding to the query in v22. + // Reply true to session ID query requests, we will set new - // session IDs for new sessions + // session IDs for new sessions directly after accepting the + // session channel request. 
if err := r.Reply(true, nil); err != nil { s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) } @@ -1646,6 +1662,27 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec trackingChan := scx.TrackActivity(ch) + // If we are creating a new session (not joining a session), prepare a new session + // ID and inform the client. + // + // Note: If this is an old client (