From eb5ee3a0400e7f400fb6df3d3121d52b89054712 Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Tue, 23 Sep 2025 16:45:20 -0700 Subject: [PATCH 1/4] Fix discrepancies between Node and Proxy recording modes. (#58707) --- integration/integration_test.go | 82 ++++++++++++++++++++++++++++++- lib/reversetunnel/localsite.go | 5 +- lib/reversetunnel/remotesite.go | 5 +- lib/srv/authhandlers.go | 2 +- lib/srv/ctx.go | 9 ++-- lib/srv/exec.go | 6 +-- lib/srv/exec_test.go | 2 - lib/srv/forward/sftp.go | 4 +- lib/srv/forward/sshserver.go | 79 ++++++++++++++++------------- lib/srv/forward/subsystem.go | 2 +- lib/srv/git/forward.go | 4 +- lib/srv/mock_test.go | 2 +- lib/srv/regular/sftp.go | 4 +- lib/srv/regular/sshserver.go | 30 ++++------- lib/srv/regular/sshserver_test.go | 2 +- lib/srv/sess.go | 10 ++-- lib/web/apiserver_test.go | 4 +- 17 files changed, 161 insertions(+), 91 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index 81735649ee648..86265202c6467 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -149,8 +149,8 @@ func TestIntegrations(t *testing.T) { t.Run("BPFSessionDifferentiation", suite.bind(testBPFSessionDifferentiation)) t.Run("ClientIdleConnection", suite.bind(testClientIdleConnection)) t.Run("CmdLabels", suite.bind(testCmdLabels)) + t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters)) t.Run("ControlMaster", suite.bind(testControlMaster)) - t.Run("X11Forwarding", suite.bind(testX11Forwarding)) t.Run("CustomReverseTunnel", suite.bind(testCustomReverseTunnel)) t.Run("DataTransfer", suite.bind(testDataTransfer)) t.Run("DifferentPinnedIP", suite.bind(testDifferentPinnedIP)) @@ -183,6 +183,7 @@ func TestIntegrations(t *testing.T) { t.Run("PAM", suite.bind(testPAM)) t.Run("PortForwarding", suite.bind(testPortForwarding)) t.Run("ProxyHostKeyCheck", suite.bind(testProxyHostKeyCheck)) + t.Run("RecordingModesSessionTrackers", 
suite.bind(testRecordingModesSessionTrackers)) t.Run("ReverseTunnelCollapse", suite.bind(testReverseTunnelCollapse)) t.Run("RotateRollback", suite.bind(testRotateRollback)) t.Run("RotateSuccess", suite.bind(testRotateSuccess)) @@ -200,12 +201,12 @@ func TestIntegrations(t *testing.T) { t.Run("TrustedClustersRoleMapChanges", suite.bind(testTrustedClustersRoleMapChanges)) t.Run("TrustedClustersWithLabels", suite.bind(testTrustedClustersWithLabels)) t.Run("TrustedClustersSkipNameValidation", suite.bind(testTrustedClustersSkipNameValidation)) - t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters)) t.Run("TrustedTunnelNode", suite.bind(testTrustedTunnelNode)) t.Run("TwoClustersProxy", suite.bind(testTwoClustersProxy)) t.Run("TwoClustersTunnel", suite.bind(testTwoClustersTunnel)) t.Run("UUIDBasedProxy", suite.bind(testUUIDBasedProxy)) t.Run("WindowChange", suite.bind(testWindowChange)) + t.Run("X11Forwarding", suite.bind(testX11Forwarding)) } // testDifferentPinnedIP tests connection is rejected when source IP doesn't match the pinned one @@ -1027,6 +1028,83 @@ func testSessionRecordingModes(t *testing.T, suite *integrationTestSuite) { } } +func testRecordingModesSessionTrackers(t *testing.T, suite *integrationTestSuite) { + ctx := t.Context() + + cfg := suite.defaultServiceConfig() + cfg.Auth.Enabled = true + cfg.Proxy.DisableWebService = true + cfg.Proxy.DisableWebInterface = true + cfg.Proxy.Enabled = true + cfg.SSH.Enabled = true + + teleport := suite.NewTeleportWithConfig(t, nil, nil, cfg) + defer teleport.StopAll() + + // startSession starts an interactive session, users must terminate the + // session by typing "exit" in the terminal. 
+ startSession := func(username string) (*Terminal, chan error) { + term := NewTerminal(250) + errCh := make(chan error) + + go func() { + cl, err := teleport.NewClient(helpers.ClientConfig{ + Login: username, + Cluster: helpers.Site, + Host: Host, + }) + if err != nil { + errCh <- trace.Wrap(err) + return + } + cl.Stdout = term + cl.Stdin = term + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + errCh <- cl.SSH(ctx, nil) + }() + + return term, errCh + } + + err := teleport.WaitForNodeCount(ctx, helpers.Site, 1) + require.NoError(t, err) + + auth := teleport.Process.GetAuthServer() + for _, mode := range []string{types.RecordAtNode, types.RecordAtProxy} { + t.Run(mode, func(t *testing.T) { + rc := types.DefaultSessionRecordingConfig() + rc.SetMode(mode) + + _, err := auth.UpsertSessionRecordingConfig(ctx, rc) + require.NoError(t, err) + + // Start session. + term, errCh := startSession(suite.Me.Username) + + // Validate that the session tracker exists and contains + // the correct target address. + var sessionID string + require.EventuallyWithT(t, func(t *assert.CollectT) { + trackers, err := auth.GetActiveSessionTrackers(ctx) + require.NoError(t, err) + require.Len(t, trackers, 1) + require.Equal(t, helpers.HostID, trackers[0].GetAddress()) + sessionID = trackers[0].GetSessionID() + }, 30*time.Second, 100*time.Millisecond) + + // Wait for the session to terminate without error. + term.Type("exit\n\r") + require.NoError(t, waitForError(errCh, 30*time.Second)) + + // Manually clean up the tracker for the session to prevent + // it leaking into the next test case. 
+ require.NoError(t, auth.RemoveSessionTracker(ctx, sessionID)) + }) + } +} + func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { tests := []struct { rootRecordingMode string diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 9dffe572b6fab..c4aa5c67c3b34 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -471,13 +471,10 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net DataDir: s.srv.Config.DataDir, Address: params.Address, UseTunnel: useTunnel, - HostUUID: s.srv.ID, + ProxyUUID: s.srv.ID, Emitter: s.srv.Config.Emitter, ParentContext: s.srv.Context, LockWatcher: s.srv.LockWatcher, - TargetID: params.ServerID, - TargetAddr: params.To.String(), - TargetHostname: params.Address, TargetServer: params.TargetServer, Clock: s.clock, } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 2f6b211cb36dc..75bf261e1d377 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -913,13 +913,10 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne Address: params.Address, UseTunnel: UseTunnel(s.logger, targetConn), FIPS: s.srv.FIPS, - HostUUID: s.srv.ID, + ProxyUUID: s.srv.ID, Emitter: s.srv.Config.Emitter, ParentContext: s.srv.Context, LockWatcher: s.srv.LockWatcher, - TargetID: params.ServerID, - TargetAddr: params.To.String(), - TargetHostname: params.Address, TargetServer: params.TargetServer, Clock: s.clock, } diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index fd05ed2931f70..44ad1571c43d5 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -701,7 +701,7 @@ func (h *AuthHandlers) hostKeyCallback(hostname string, remote net.Addr, key ssh ctx := h.c.Server.Context() // For SubKindOpenSSHEICENode we use SSH Keys (EC2 does not support Certificates in ec2.SendSSHPublicKey). 
- if h.c.Server.TargetMetadata().ServerSubKind == types.SubKindOpenSSHEICENode { + if h.c.Server.GetInfo().GetSubKind() == types.SubKindOpenSSHEICENode { return nil } diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 0324ac96833c0..ebc4704719121 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -155,6 +155,7 @@ type Server interface { GetClock() clockwork.Clock // GetInfo returns a services.Server that represents this server. + // In the case of the Proxy forwarder, this is the node target. GetInfo() types.Server // UseTunnel used to determine if this node has connected to this cluster @@ -189,8 +190,8 @@ type Server interface { // support or not. GetSELinuxEnabled() bool - // TargetMetadata returns metadata about the session target node. - TargetMetadata() apievents.ServerMetadata + // EventMetadata returns [events.ServerMetadata] for this server. + EventMetadata() apievents.ServerMetadata } // IdentityContext holds all identity information associated with the user @@ -483,7 +484,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s clientIdleTimeout: clientIdleTimeout, cancelContext: cancelContext, cancel: cancel, - ServerSubKind: srv.TargetMetadata().ServerSubKind, + ServerSubKind: srv.GetInfo().GetSubKind(), } child.Logger = slog.With( @@ -903,7 +904,7 @@ func (c *ServerContext) reportStats(conn utils.Stater) { Type: events.SessionDataEvent, Code: events.SessionDataCode, }, - ServerMetadata: c.srv.TargetMetadata(), + ServerMetadata: c.srv.EventMetadata(), SessionMetadata: c.GetSessionMetadata(), UserMetadata: c.Identity.GetUserMetadata(), ConnectionMetadata: apievents.ConnectionMetadata{ diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 2b65d301a0cc9..8daf6c7c1c31f 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -264,7 +264,7 @@ func (e *localExec) transformSecureCopy() error { Time: time.Now(), }, UserMetadata: e.Ctx.Identity.GetUserMetadata(), - ServerMetadata: e.Ctx.GetServer().TargetMetadata(), + ServerMetadata: 
e.Ctx.GetServer().EventMetadata(), Error: err.Error(), }) return trace.Wrap(err) @@ -369,7 +369,7 @@ func (e *remoteExec) Start(ctx context.Context, ch ssh.Channel) (*ExecResult, er Time: time.Now(), }, UserMetadata: e.ctx.Identity.GetUserMetadata(), - ServerMetadata: e.ctx.GetServer().TargetMetadata(), + ServerMetadata: e.ctx.GetServer().EventMetadata(), Error: err.Error(), }) return nil, trace.Wrap(err) @@ -435,7 +435,7 @@ func (e *remoteExec) PID() int { // instead of ctx.srv. func emitExecAuditEvent(ctx *ServerContext, cmd string, execErr error) { // Create common fields for event. - serverMeta := ctx.GetServer().TargetMetadata() + serverMeta := ctx.GetServer().EventMetadata() sessionMeta := ctx.GetSessionMetadata() userMeta := ctx.Identity.GetUserMetadata() diff --git a/lib/srv/exec_test.go b/lib/srv/exec_test.go index f79242275d1ca..086a43e7985a2 100644 --- a/lib/srv/exec_test.go +++ b/lib/srv/exec_test.go @@ -64,8 +64,6 @@ func TestEmitExecAuditEvent(t *testing.T) { rec, ok := scx.session.recorder.(*mockRecorder) require.True(t, ok) - scx.GetServer().TargetMetadata() - expectedUsr, err := user.Current() require.NoError(t, err) expectedHostname := "testHost" diff --git a/lib/srv/forward/sftp.go b/lib/srv/forward/sftp.go index f158ee9f1f853..9e9a60a60bd79 100644 --- a/lib/srv/forward/sftp.go +++ b/lib/srv/forward/sftp.go @@ -94,7 +94,7 @@ func (p *SFTPProxy) Serve() error { Code: events.SFTPSummaryCode, Time: time.Now(), }, - ServerMetadata: scx.GetServer().TargetMetadata(), + ServerMetadata: scx.GetServer().EventMetadata(), SessionMetadata: scx.GetSessionMetadata(), UserMetadata: scx.Identity.GetUserMetadata(), ConnectionMetadata: apievents.ConnectionMetadata{ @@ -230,7 +230,7 @@ func (h *proxyHandlers) sendSFTPEvent(req *sftp.Request, reqErr error) { } else if reqErr != nil { h.logger.DebugContext(req.Context(), "failed handling SFTP request", "request", req.Method, "error", reqErr) } - event.ServerMetadata = h.scx.GetServer().TargetMetadata() + 
event.ServerMetadata = h.scx.GetServer().EventMetadata() event.SessionMetadata = h.scx.GetSessionMetadata() event.UserMetadata = h.scx.Identity.GetUserMetadata() event.ConnectionMetadata = apievents.ConnectionMetadata{ diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index b4bffc04c99aa..834a1ef7609b4 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -30,7 +30,6 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" @@ -83,8 +82,6 @@ import ( type Server struct { logger *slog.Logger - id string - // targetConn is the TCP connection to the remote host. targetConn net.Conn @@ -155,9 +152,9 @@ type Server struct { clock clockwork.Clock - // hostUUID is the UUID of the underlying proxy that the forwarding server + // proxyUUID is the UUID of the underlying proxy that the forwarding server // is running in. - hostUUID string + proxyUUID string // closeContext and closeCancel are used to signal to the outside // world that this server is closed @@ -174,9 +171,6 @@ type Server struct { // of starting spans. tracerProvider oteltrace.TracerProvider - // TODO(Joerger): Remove in favor of targetServer, which has more accurate values. - targetID, targetAddr, targetHostname string - // targetServer is the host that the connection is being established for. targetServer types.Server } @@ -229,9 +223,9 @@ type ServerConfig struct { // configuration. FIPS bool - // HostUUID is the UUID of the underlying proxy that the forwarding server + // ProxyUUID is the UUID of the underlying proxy that the forwarding server // is running in. - HostUUID string + ProxyUUID string // Emitter is audit events emitter Emitter events.StreamEmitter @@ -247,9 +241,6 @@ type ServerConfig struct { // of starting spans. 
TracerProvider oteltrace.TracerProvider - // TODO(Joerger): Remove in favor of TargetServer, which has more accurate values. - TargetID, TargetAddr, TargetHostname string - // TargetServer is the host that the connection is being established for. TargetServer types.Server } @@ -331,7 +322,6 @@ func New(c ServerConfig) (*Server, error) { "src_addr", c.SrcAddr.String(), "dst_addr", c.DstAddr.String(), ), - id: uuid.New().String(), targetConn: c.TargetConn, serverConn: utils.NewTrackingConn(serverConn), clientConn: clientConn, @@ -344,14 +334,11 @@ func New(c ServerConfig) (*Server, error) { authService: c.LocalAuthClient, dataDir: c.DataDir, clock: c.Clock, - hostUUID: c.HostUUID, + proxyUUID: c.ProxyUUID, StreamEmitter: c.Emitter, parentContext: c.ParentContext, lockWatcher: c.LockWatcher, tracerProvider: c.TracerProvider, - targetID: c.TargetID, - targetAddr: c.TargetAddr, - targetHostname: c.TargetHostname, targetServer: c.TargetServer, } @@ -397,16 +384,18 @@ func New(c ServerConfig) (*Server, error) { return s, nil } -// TargetMetadata returns metadata about the forwarding target. -func (s *Server) TargetMetadata() apievents.ServerMetadata { +// EventMetadata returns metadata about the forwarding target. 
+func (s *Server) EventMetadata() apievents.ServerMetadata { + serverInfo := s.GetInfo() return apievents.ServerMetadata{ ServerVersion: teleport.Version, - ServerNamespace: s.GetNamespace(), - ServerID: s.targetID, - ServerAddr: s.targetAddr, - ServerHostname: s.targetHostname, - ForwardedBy: s.hostUUID, - ServerSubKind: s.targetServer.GetSubKind(), + ServerNamespace: serverInfo.GetNamespace(), + ServerID: serverInfo.GetName(), + ServerAddr: serverInfo.GetAddr(), + ServerLabels: serverInfo.GetAllLabels(), + ServerHostname: serverInfo.GetHostname(), + ServerSubKind: serverInfo.GetSubKind(), + ForwardedBy: s.proxyUUID, } } @@ -421,15 +410,15 @@ func (s *Server) GetDataDir() string { return s.dataDir } -// ID returns the ID of the proxy that creates the in-memory forwarding server. +// ID returns the UUID of the server targeted by the forwarding server. func (s *Server) ID() string { - return s.id + return s.targetServer.GetName() } // HostUUID is the UUID of the underlying proxy that the forwarding server // is running in. func (s *Server) HostUUID() string { - return s.hostUUID + return s.proxyUUID } // GetNamespace returns the namespace the forwarding server resides in. @@ -502,19 +491,39 @@ func (s *Server) GetSELinuxEnabled() bool { return false } -// GetInfo returns a services.Server that represents this server. +// GetInfo returns a services.Server that represents the target server. func (s *Server) GetInfo() types.Server { - return &types.ServerV2{ + return s.getBasicInfo() +} + +func (s *Server) getBasicInfo() *types.ServerV2 { + // Only set the address for non-tunnel nodes. 
+ var addr string + if !s.targetServer.GetUseTunnel() { + addr = s.targetServer.GetAddr() + } + + srv := &types.ServerV2{ Kind: types.KindNode, + SubKind: s.targetServer.GetSubKind(), Version: types.V2, Metadata: types.Metadata{ - Name: s.ID(), - Namespace: s.GetNamespace(), + Name: s.targetServer.GetName(), + Namespace: s.targetServer.GetNamespace(), + Labels: s.targetServer.GetLabels(), }, Spec: types.ServerSpecV2{ - Addr: s.AdvertiseAddr(), + CmdLabels: types.LabelsToV2(s.targetServer.GetCmdLabels()), + Addr: addr, + Hostname: s.targetServer.GetHostname(), + UseTunnel: s.useTunnel, + Version: teleport.Version, + ProxyIDs: s.targetServer.GetProxyIDs(), + PublicAddrs: s.targetServer.GetPublicAddrs(), }, } + + return srv } // Dial returns the client connection created by pipeAddrConn. @@ -1470,7 +1479,7 @@ func (s *Server) handleSubsystem(ctx context.Context, ch ssh.Channel, req *ssh.R Time: time.Now(), }, UserMetadata: serverContext.Identity.GetUserMetadata(), - ServerMetadata: serverContext.GetServer().TargetMetadata(), + ServerMetadata: serverContext.GetServer().EventMetadata(), Error: err.Error(), }) return trace.Wrap(err) diff --git a/lib/srv/forward/subsystem.go b/lib/srv/forward/subsystem.go index fbcb4957183e4..0c76d3f2a2a25 100644 --- a/lib/srv/forward/subsystem.go +++ b/lib/srv/forward/subsystem.go @@ -164,7 +164,7 @@ func (r *remoteSubsystem) emitAuditEvent(ctx context.Context, err error) { RemoteAddr: r.serverContext.RemoteClient.RemoteAddr().String(), }, Name: r.subsystemName, - ServerMetadata: r.serverContext.GetServer().TargetMetadata(), + ServerMetadata: r.serverContext.GetServer().EventMetadata(), } if err != nil { diff --git a/lib/srv/git/forward.go b/lib/srv/git/forward.go index 1427e7173e52c..0777192906016 100644 --- a/lib/srv/git/forward.go +++ b/lib/srv/git/forward.go @@ -566,7 +566,7 @@ func (s *ForwardServer) makeGitCommandEvent(sctx *sessionContext, command string RemoteAddr: sctx.ServerConn.RemoteAddr().String(), LocalAddr: 
sctx.ServerConn.LocalAddr().String(), }, - ServerMetadata: s.TargetMetadata(), + ServerMetadata: s.EventMetadata(), } if err != nil { event.Metadata.Code = events.GitCommandFailureCode @@ -663,7 +663,7 @@ func makeRemoteSigner(ctx context.Context, cfg *ForwardServerConfig, identityCtx func (s *ForwardServer) Context() context.Context { return s.cfg.ParentContext } -func (s *ForwardServer) TargetMetadata() apievents.ServerMetadata { +func (s *ForwardServer) EventMetadata() apievents.ServerMetadata { return apievents.ServerMetadata{ ServerVersion: teleport.Version, ServerNamespace: s.cfg.TargetServer.GetNamespace(), diff --git a/lib/srv/mock_test.go b/lib/srv/mock_test.go index 61546e42e783a..78f1980f727ea 100644 --- a/lib/srv/mock_test.go +++ b/lib/srv/mock_test.go @@ -253,7 +253,7 @@ func (m *mockServer) GetInfo() types.Server { } } -func (m *mockServer) TargetMetadata() apievents.ServerMetadata { +func (m *mockServer) EventMetadata() apievents.ServerMetadata { return apievents.ServerMetadata{ ServerID: "123", ForwardedBy: "abc", diff --git a/lib/srv/regular/sftp.go b/lib/srv/regular/sftp.go index b58efa842999d..c47278b382872 100644 --- a/lib/srv/regular/sftp.go +++ b/lib/srv/regular/sftp.go @@ -77,7 +77,7 @@ func (s *sftpSubsys) Start(ctx context.Context, Time: time.Now(), }, UserMetadata: serverCtx.Identity.GetUserMetadata(), - ServerMetadata: serverCtx.GetServer().TargetMetadata(), + ServerMetadata: serverCtx.GetServer().EventMetadata(), Error: srv.ErrNodeFileCopyingNotPermitted.Error(), }) return srv.ErrNodeFileCopyingNotPermitted @@ -168,7 +168,7 @@ func (s *sftpSubsys) Start(ctx context.Context, defer auditPipeOut.Close() // Create common fields for events - serverMeta := serverCtx.GetServer().TargetMetadata() + serverMeta := serverCtx.GetServer().EventMetadata() sessionMeta := serverCtx.GetSessionMetadata() userMeta := serverCtx.Identity.GetUserMetadata() connectionMeta := apievents.ConnectionMetadata{ diff --git a/lib/srv/regular/sshserver.go 
b/lib/srv/regular/sshserver.go index 84910ab87c7df..8f588dc646bb8 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -258,15 +258,17 @@ type Server struct { scope string } -// TargetMetadata returns metadata about the server. -func (s *Server) TargetMetadata() apievents.ServerMetadata { +// EventMetadata returns metadata about the server. +func (s *Server) EventMetadata() apievents.ServerMetadata { + serverInfo := s.GetInfo() return apievents.ServerMetadata{ ServerVersion: teleport.Version, - ServerNamespace: s.GetNamespace(), - ServerID: s.ID(), - ServerAddr: s.Addr(), - ServerLabels: s.getAllLabels(), - ServerHostname: s.hostname, + ServerNamespace: serverInfo.GetNamespace(), + ServerID: serverInfo.GetName(), + ServerAddr: serverInfo.GetAddr(), + ServerLabels: serverInfo.GetAllLabels(), + ServerHostname: serverInfo.GetHostname(), + ServerSubKind: serverInfo.GetSubKind(), } } @@ -1110,18 +1112,6 @@ func (s *Server) getDynamicLabels() map[string]types.CommandLabelV2 { return types.LabelsToV2(s.dynamicLabels.Get()) } -// getAllLabels return a combination of static and dynamic labels. -func (s *Server) getAllLabels() map[string]string { - lmap := make(map[string]string) - for key, value := range s.getStaticLabels() { - lmap[key] = value - } - for key, cmd := range s.getDynamicLabels() { - lmap[key] = cmd.Result - } - return lmap -} - // GetInfo returns a services.Server that represents this server. 
func (s *Server) GetInfo() types.Server { return s.getBasicInfo() @@ -2461,7 +2451,7 @@ func (s *Server) parseSubsystemRequest(ctx context.Context, req *ssh.Request, se Time: time.Now(), }, UserMetadata: serverContext.Identity.GetUserMetadata(), - ServerMetadata: serverContext.GetServer().TargetMetadata(), + ServerMetadata: serverContext.GetServer().EventMetadata(), Error: err.Error(), }) return nil, trace.Wrap(err) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 8f16d56aa1bd2..a4a164be8ca90 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -3092,7 +3092,7 @@ func TestTargetMetadata(t *testing.T) { serverOptions...) require.NoError(t, err) - metadata := sshSrv.TargetMetadata() + metadata := sshSrv.EventMetadata() require.Equal(t, nodeID, metadata.ServerID) require.Equal(t, apidefaults.Namespace, metadata.ServerNamespace) require.Empty(t, metadata.ServerAddr) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index e8581696cdd35..66ae5e07d897e 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -881,7 +881,7 @@ func newSession(ctx context.Context, r *SessionRegistry, scx *ServerContext, ch cluster: scx.Identity.OriginClusterName, }, displayParticipantRequirements: utils.AsBool(scx.env[teleport.EnvSSHSessionDisplayParticipantRequirements]), - serverMeta: scx.srv.TargetMetadata(), + serverMeta: scx.srv.EventMetadata(), } sess.io.OnWriteError = sess.onWriteErrorCallback(sessionRecordingMode) @@ -1439,7 +1439,7 @@ func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *p Emitter: s.emitter, Namespace: scx.srv.GetNamespace(), SessionID: s.id.String(), - ServerID: scx.srv.HostUUID(), + ServerID: scx.srv.ID(), ServerHostname: scx.srv.GetInfo().GetHostname(), Login: scx.Identity.Login, User: scx.Identity.TeleportUser, @@ -1657,7 +1657,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve Emitter: s.emitter, Namespace: scx.srv.GetNamespace(), 
SessionID: string(s.id), - ServerID: scx.srv.HostUUID(), + ServerID: scx.srv.ID(), ServerHostname: scx.srv.GetInfo().GetHostname(), Login: scx.Identity.Login, User: scx.Identity.TeleportUser, @@ -2366,7 +2366,7 @@ func (s *session) trackSession(ctx context.Context, teleportUser string, policyS Kind: string(types.SSHSessionKind), State: types.SessionState_SessionStatePending, Hostname: s.serverMeta.ServerHostname, - Address: s.serverMeta.ServerID, + Address: s.scx.srv.ID(), ClusterName: s.scx.ClusterName, Login: s.login, HostUser: teleportUser, @@ -2382,7 +2382,7 @@ func (s *session) trackSession(ctx context.Context, teleportUser string, policyS LastActive: s.registry.clock.Now().UTC(), }, }, - HostID: s.registry.Srv.ID(), + HostID: s.registry.Srv.HostUUID(), TargetSubKind: s.serverMeta.ServerSubKind, InitialCommand: initialCommand, } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 46596df72628f..c1d9236ddc0c4 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -2305,7 +2305,7 @@ func TestTerminalRouting(t *testing.T) { sess := term.GetSession() - metadata := tt.target.TargetMetadata() + metadata := tt.target.EventMetadata() require.Equal(t, metadata.ServerID, sess.ServerID) require.Equal(t, metadata.ServerHostname, sess.ServerHostname) @@ -8638,7 +8638,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula AccessPoint: client, LockWatcher: proxyLockWatcher, Clock: clock, - ServerID: proxyID, + ServerID: node.ID(), Emitter: client, EmitterContext: ctx, Logger: logtest.NewLogger(), From 9a85d8a86b23c31bc0112e56a20a9e88e7805587 Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Fri, 26 Sep 2025 08:13:26 -0700 Subject: [PATCH 2/4] Replace flaky test with more straightforward event metadata test. 
(#59610) --- integration/integration_test.go | 78 ---------------------- lib/srv/forward/sshserver.go | 4 -- lib/srv/forward/sshserver_test.go | 103 ++++++++++++++++++++++++++++++ lib/srv/regular/sshserver_test.go | 2 +- 4 files changed, 104 insertions(+), 83 deletions(-) diff --git a/integration/integration_test.go b/integration/integration_test.go index 86265202c6467..d3f0ce858da45 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -183,7 +183,6 @@ func TestIntegrations(t *testing.T) { t.Run("PAM", suite.bind(testPAM)) t.Run("PortForwarding", suite.bind(testPortForwarding)) t.Run("ProxyHostKeyCheck", suite.bind(testProxyHostKeyCheck)) - t.Run("RecordingModesSessionTrackers", suite.bind(testRecordingModesSessionTrackers)) t.Run("ReverseTunnelCollapse", suite.bind(testReverseTunnelCollapse)) t.Run("RotateRollback", suite.bind(testRotateRollback)) t.Run("RotateSuccess", suite.bind(testRotateSuccess)) @@ -1028,83 +1027,6 @@ func testSessionRecordingModes(t *testing.T, suite *integrationTestSuite) { } } -func testRecordingModesSessionTrackers(t *testing.T, suite *integrationTestSuite) { - ctx := t.Context() - - cfg := suite.defaultServiceConfig() - cfg.Auth.Enabled = true - cfg.Proxy.DisableWebService = true - cfg.Proxy.DisableWebInterface = true - cfg.Proxy.Enabled = true - cfg.SSH.Enabled = true - - teleport := suite.NewTeleportWithConfig(t, nil, nil, cfg) - defer teleport.StopAll() - - // startSession starts an interactive session, users must terminate the - // session by typing "exit" in the terminal. 
- startSession := func(username string) (*Terminal, chan error) { - term := NewTerminal(250) - errCh := make(chan error) - - go func() { - cl, err := teleport.NewClient(helpers.ClientConfig{ - Login: username, - Cluster: helpers.Site, - Host: Host, - }) - if err != nil { - errCh <- trace.Wrap(err) - return - } - cl.Stdout = term - cl.Stdin = term - - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - errCh <- cl.SSH(ctx, nil) - }() - - return term, errCh - } - - err := teleport.WaitForNodeCount(ctx, helpers.Site, 1) - require.NoError(t, err) - - auth := teleport.Process.GetAuthServer() - for _, mode := range []string{types.RecordAtNode, types.RecordAtProxy} { - t.Run(mode, func(t *testing.T) { - rc := types.DefaultSessionRecordingConfig() - rc.SetMode(mode) - - _, err := auth.UpsertSessionRecordingConfig(ctx, rc) - require.NoError(t, err) - - // Start session. - term, errCh := startSession(suite.Me.Username) - - // Validate that the session tracker exists and contains - // the correct target address. - var sessionID string - require.EventuallyWithT(t, func(t *assert.CollectT) { - trackers, err := auth.GetActiveSessionTrackers(ctx) - require.NoError(t, err) - require.Len(t, trackers, 1) - require.Equal(t, helpers.HostID, trackers[0].GetAddress()) - sessionID = trackers[0].GetSessionID() - }, 30*time.Second, 100*time.Millisecond) - - // Wait for the session to terminate without error. - term.Type("exit\n\r") - require.NoError(t, waitForError(errCh, 30*time.Second)) - - // Manually clean up the tracker for the session to prevent - // it leaking into the next test case. 
- require.NoError(t, auth.RemoveSessionTracker(ctx, sessionID)) - }) - } -} - func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { tests := []struct { rootRecordingMode string diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 834a1ef7609b4..177c6f6aea062 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -493,10 +493,6 @@ func (s *Server) GetSELinuxEnabled() bool { // GetInfo returns a services.Server that represents the target server. func (s *Server) GetInfo() types.Server { - return s.getBasicInfo() -} - -func (s *Server) getBasicInfo() *types.ServerV2 { // Only set the address for non-tunnel nodes. var addr string if !s.targetServer.GetUseTunnel() { diff --git a/lib/srv/forward/sshserver_test.go b/lib/srv/forward/sshserver_test.go index e46201010d898..cac12932e1668 100644 --- a/lib/srv/forward/sshserver_test.go +++ b/lib/srv/forward/sshserver_test.go @@ -28,12 +28,15 @@ import ( "sync/atomic" "testing" + "github.com/google/uuid" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/authclient" @@ -380,3 +383,103 @@ func TestServerConfigCheckDefaults(t *testing.T) { }) } } + +func TestEventMetadata(t *testing.T) { + nodeID := uuid.NewString() + proxyID := uuid.NewString() + + for _, tt := range []struct { + name string + subkind string + spec types.ServerSpecV2 + labels map[string]string + expectMetadata events.ServerMetadata + }{ + { + name: "tunnel node", + labels: map[string]string{ + "stcLabel": "stcResult", + }, + spec: types.ServerSpecV2{ + Addr: "127.0.0.1:3022", + 
CmdLabels: map[string]types.CommandLabelV2{ + "cmdLabel": {Result: "cmdResult"}, + }, + Hostname: "server01", + UseTunnel: true, + }, + expectMetadata: events.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: nodeID, + ServerNamespace: apidefaults.Namespace, + ServerAddr: "", + ServerHostname: "server01", + ServerLabels: map[string]string{ + "stcLabel": "stcResult", + "cmdLabel": "cmdResult", + }, + ServerSubKind: types.SubKindTeleportNode, + ForwardedBy: proxyID, + }, + }, { + name: "tunnel node", + labels: map[string]string{ + "stcLabel": "stcResult", + }, + spec: types.ServerSpecV2{ + Addr: "127.0.0.1:3022", + CmdLabels: map[string]types.CommandLabelV2{ + "cmdLabel": {Result: "cmdResult"}, + }, + Hostname: "server01", + }, + expectMetadata: events.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: nodeID, + ServerNamespace: apidefaults.Namespace, + ServerAddr: "127.0.0.1:3022", + ServerHostname: "server01", + ServerLabels: map[string]string{ + "stcLabel": "stcResult", + "cmdLabel": "cmdResult", + }, + ServerSubKind: types.SubKindTeleportNode, + ForwardedBy: proxyID, + }, + }, { + name: "agentless node", + subkind: types.SubKindOpenSSHNode, + labels: map[string]string{ + "stcLabel": "stcResult", + }, + spec: types.ServerSpecV2{ + Addr: "openssh.example.com:22", + Hostname: "agentless-host", + }, + expectMetadata: events.ServerMetadata{ + ServerVersion: teleport.Version, + ServerID: nodeID, + ServerNamespace: apidefaults.Namespace, + ServerAddr: "openssh.example.com:22", + ServerHostname: "agentless-host", + ServerLabels: map[string]string{ + "stcLabel": "stcResult", + }, + ServerSubKind: types.SubKindOpenSSHNode, + ForwardedBy: proxyID, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + targetServer, err := types.NewNode(nodeID, tt.subkind, tt.spec, tt.labels) + require.NoError(t, err) + + forwardSrv := &Server{ + proxyUUID: proxyID, + targetServer: targetServer, + } + + require.EqualValues(t, tt.expectMetadata, 
forwardSrv.EventMetadata()) + }) + } +} diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index a4a164be8ca90..6d515a6ab7371 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -3025,7 +3025,7 @@ func TestHandlePuTTYWinadj(t *testing.T) { require.Equal(t, "hello once more\n", string(out)) } -func TestTargetMetadata(t *testing.T) { +func TestEventMetadata(t *testing.T) { ctx := context.Background() testServer, err := authtest.NewTestServer(authtest.ServerConfig{ Auth: authtest.AuthServerConfig{ From f9d1749b8af1a922af28bed68af9db9cd78be45e Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Fri, 3 Oct 2025 16:44:16 -0700 Subject: [PATCH 3/4] Make SSH session client provide session params upfront rather than with synchronous `envs@goteleport.com` requests (#59206) --- api/observability/tracing/ssh/client.go | 73 ++++++++++++++- integration/integration_test.go | 6 +- lib/client/api.go | 19 +--- lib/client/client.go | 52 ++++++----- lib/client/client_test.go | 12 +-- lib/client/session.go | 64 ++++++++----- lib/srv/ctx.go | 73 +++++++-------- lib/srv/forward/sshserver.go | 24 ++++- lib/srv/git/forward.go | 11 ++- lib/srv/regular/sshserver.go | 28 ++++-- lib/srv/regular/sshserver_test.go | 118 +++++++++++++++++++++++- lib/srv/sess.go | 40 +++----- lib/srv/sess_test.go | 13 ++- lib/srv/term.go | 6 +- lib/srv/termhandlers.go | 4 +- lib/sshutils/sftp/http.go | 6 +- lib/sshutils/sftp/remote.go | 2 +- lib/web/files.go | 24 +++-- lib/web/terminal.go | 7 +- 19 files changed, 395 insertions(+), 187 deletions(-) diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index 714a8607f429e..28f0f8d014726 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -22,6 +22,7 @@ import ( "sync" "sync/atomic" + "github.com/google/uuid" "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" semconv 
"go.opentelemetry.io/otel/semconv/v1.10.0" @@ -29,6 +30,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" ) // Client is a wrapper around ssh.Client that adds tracing support. @@ -184,9 +186,65 @@ func (c *Client) OpenChannel( }, reqs, err } -// NewSession creates a new SSH session that is passed tracing context -// so that spans may be correlated properly over the ssh connection. +// SessionParams are session parameters supported by Teleport to provide additional +// session context or parameters to the server. +type SessionParams struct { + // WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server. + WebProxyAddr string + // Reason is a reason attached to started sessions meant to describe their intent. + Reason string + // Invited is a list of people invited to a session. + Invited []string + // DisplayParticipantRequirements is set if debug information about participants requirements + // should be printed in moderated sessions. + DisplayParticipantRequirements bool + // JoinSessionID is the ID of a session to join. + JoinSessionID string + // JoinMode is the participant mode to join the session with. + // Required if JoinSessionID is set. + JoinMode types.SessionParticipantMode + // ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session + // to check for valid FileTransferRequests. + ModeratedSessionID string +} + +// ParseSessionParams unmarshals session parameters which have been [ssh.Marshal]ed by the client +// and provided as extra data in the session channel request. If the provided data is empty, nil params +// will be returned with a nil error. 
+func ParseSessionParams(data []byte) (*SessionParams, error) { + if len(data) == 0 { + return nil, nil + } + + var params SessionParams + if err := ssh.Unmarshal(data, &params); err != nil { + return nil, trace.Wrap(err) + } + + if params.JoinSessionID != "" { + if _, err := uuid.Parse(params.JoinSessionID); err != nil { + return nil, trace.Wrap(err, "failed to parse join session ID: %v", params.JoinSessionID) + } + + switch params.JoinMode { + case types.SessionModeratorMode, types.SessionObserverMode, types.SessionPeerMode: + default: + return nil, trace.BadParameter("Unrecognized session participant mode: %q", params.JoinMode) + } + } + + return &params, nil +} + +// NewSession creates a new SSH session. This session is passed a tracing context so that +// spans may be correlated properly over the ssh connection. func (c *Client) NewSession(ctx context.Context) (*Session, error) { + return c.NewSessionWithParams(ctx, nil) +} + +// NewSessionWithParams creates a new SSH session with the given (optional) params. This session is +// passed a tracing context so that spans may be correlated properly over the ssh connection. +func (c *Client) NewSessionWithParams(ctx context.Context, sessionParams *SessionParams) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) ctx, span := tracer.Start( @@ -213,9 +271,16 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { contexts: make(map[string][]context.Context), } + // If we are connected to a Teleport server, send session params in the session request. + // If the server does not support session parameters in the extra data, it will be ignored.
+ var sessionData []byte + if sessionParams != nil && c.capability == tracingSupported { + sessionData = ssh.Marshal(sessionParams) + } + // open a session manually so we can take ownership of the // requests chan - ch, reqs, err := wrapper.OpenChannel("session", nil) + ch, reqs, err := wrapper.OpenChannel("session", sessionData) if err != nil { return nil, trace.Wrap(err) } @@ -236,7 +301,7 @@ func (c *Client) NewSession(ctx context.Context) (*Session, error) { } // RequestHandlerFn is an ssh request handler function. -type RequestHandlerFn func(ctx context.Context, ch *ssh.Request) +type RequestHandlerFn func(ctx context.Context, req *ssh.Request) // HandleSessionRequest registers a handler for any incoming [ssh.Request] matching the // provided type within a session. If the type is already being handled, an error is returned. diff --git a/integration/integration_test.go b/integration/integration_test.go index d3f0ce858da45..33772ccecdcd8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -1120,7 +1120,7 @@ func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { ) assert.NoError(t, err) - errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil) + errCh <- nodeClient.RunInteractiveShell(ctx, "", "", nil) assert.NoError(t, nodeClient.Close()) }() @@ -7983,7 +7983,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) { isNilOrEOFErr(t, transferSess.Close()) }) - err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) + err = transferSess.Setenv(ctx, string(telesftp.EnvModeratedSessionID), sessTracker.GetSessionID()) require.NoError(t, err) err = transferSess.RequestSubsystem(ctx, teleport.SFTPSubsystem) @@ -8045,7 +8045,7 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) { require.NoError(t, transferSess.Close()) }) - err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) 
+ err = transferSess.Setenv(ctx, string(telesftp.EnvModeratedSessionID), sessTracker.GetSessionID()) require.NoError(t, err) // Test that only operations needed to complete the download diff --git a/lib/client/api.go b/lib/client/api.go index 76df566e80ad6..048cf40037fef 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -94,7 +94,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -2303,7 +2302,7 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt // Reuse the existing nodeClient we connected above. return nodeClient.RunCommand(ctx, command) } - return trace.Wrap(nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil)) + return trace.Wrap(nodeClient.RunInteractiveShell(ctx, "", "", nil)) } func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context, clt *ClusterClient, nodes []TargetNode, command []string) error { @@ -2457,7 +2456,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan } // running shell with a given session means "join" it: - err = nc.RunInteractiveShell(ctx, mode, session, beforeStart) + err = nc.RunInteractiveShell(ctx, sessionID.String(), mode, beforeStart) return trace.Wrap(err) } @@ -3155,20 +3154,6 @@ func (tc *TeleportClient) writeCommandResults(nodes []execResult) error { return nil } -func (tc *TeleportClient) newSessionEnv() map[string]string { - env := map[string]string{ - teleport.SSHSessionWebProxyAddr: tc.WebProxyAddr, - } - if tc.SessionID != "" { - env[sshutils.SessionEnvVar] = tc.SessionID - } - - for key, val := range tc.ExtraEnvs { - env[key] = val - } - return env -} - // getProxyLogin determines which 
SSH principal to use when connecting to proxy. func (tc *TeleportClient) getProxySSHPrincipal() string { if tc.ProxySSHPrincipal != "" { diff --git a/lib/client/client.go b/lib/client/client.go index bb3aca78fc330..4dfed45e5e074 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -28,7 +28,6 @@ import ( "net" "os" "path/filepath" - "strconv" "strings" "sync" "time" @@ -391,10 +390,10 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co return nc, nil } -// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr +// RunInteractiveShell creates or joins an interactive shell on the node and copies stdin/stdout/stderr // to and from the node and local shell. This will block until the interactive shell on the node // is terminated. -func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker, beforeStart func(io.Writer)) error { +func (c *NodeClient) RunInteractiveShell(ctx context.Context, joinSessionID string, joinMode types.SessionParticipantMode, beforeStart func(io.Writer)) error { ctx, span := c.Tracer.Start( ctx, "nodeClient/RunInteractiveShell", @@ -402,28 +401,21 @@ func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.Session ) defer span.End() - env := c.TC.newSessionEnv() - env[teleport.EnvSSHJoinMode] = string(mode) - env[teleport.EnvSSHSessionReason] = c.TC.Config.Reason - env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(c.TC.Config.DisplayParticipantRequirements) - encoded, err := json.Marshal(&c.TC.Config.Invited) - if err != nil { - return trace.Wrap(err) + sessionParams := &tracessh.SessionParams{ + WebProxyAddr: c.WebProxyAddr(), + Reason: c.TC.Config.Reason, + Invited: c.TC.Config.Invited, + DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements, + JoinSessionID: joinSessionID, + JoinMode: joinMode, } - env[teleport.EnvSSHSessionInvited] 
= string(encoded) - // Overwrite "SSH_SESSION_WEBPROXY_ADDR" with the public addr reported by the proxy. Otherwise, - // this would be set to the localhost addr (tc.WebProxyAddr) used for Web UI client connections. - if c.ProxyPublicAddr != "" && c.TC.WebProxyAddr != c.ProxyPublicAddr { - env[teleport.SSHSessionWebProxyAddr] = c.ProxyPublicAddr - } - - nodeSession, err := newSession(ctx, c, sessToJoin, env, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences) + nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, !c.TC.DisableEscapeSequences) if err != nil { return trace.Wrap(err) } - if err = nodeSession.runShell(ctx, mode, beforeStart, c.TC.OnShellCreated); err != nil { + if err = nodeSession.runShell(ctx, sessionParams, beforeStart, c.TC.OnShellCreated); err != nil { var exitErr *ssh.ExitError var exitMissingErr *ssh.ExitMissingError switch err := trace.Unwrap(err); { @@ -619,13 +611,19 @@ func (c *NodeClient) RunCommand(ctx context.Context, command []string, opts ...R } } - nodeSession, err := newSession(ctx, c, nil, c.TC.newSessionEnv(), c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences) + sessionParams := &tracessh.SessionParams{ + WebProxyAddr: c.WebProxyAddr(), + Reason: c.TC.Config.Reason, + Invited: c.TC.Config.Invited, + DisplayParticipantRequirements: c.TC.Config.DisplayParticipantRequirements, + } + + nodeSession, err := newSession(ctx, c, sessionParams, c.TC.Stdin, stdout, stderr, !c.TC.DisableEscapeSequences) if err != nil { return trace.Wrap(err) } defer nodeSession.Close() - - err = nodeSession.runCommand(ctx, types.SessionPeerMode, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) + err = nodeSession.runCommand(ctx, sessionParams, command, c.TC.OnShellCreated, c.TC.Config.InteractiveCommand) if err != nil { c.TC.SetExitStatus(getExitStatus(err)) } @@ -1009,3 +1007,13 @@ func GetPaginatedSessions(ctx context.Context, fromUTC, toUTC time.Time, pageSiz } return sessions, nil } + 
+// WebProxyAddr is the address of the proxy forwarding the SSH connection to the target server. +func (c *NodeClient) WebProxyAddr() string { + // Prioritize the public addr reported by the proxy. Otherwise, this would + // return the localhost addr used for Web UI client connections. + if c.ProxyPublicAddr != "" { + return c.ProxyPublicAddr + } + return c.TC.WebProxyAddr +} diff --git a/lib/client/client_test.go b/lib/client/client_test.go index f53c6ec317cb1..5ee7659c4233a 100644 --- a/lib/client/client_test.go +++ b/lib/client/client_test.go @@ -36,7 +36,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/lib/observability/tracing" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" ) @@ -48,12 +47,13 @@ func TestHelperFunctions(t *testing.T) { func TestNewSession(t *testing.T) { nc := &NodeClient{ + TC: &TeleportClient{}, Tracer: tracing.NoopProvider().Tracer("test"), } ctx := context.Background() // defaults: - ses, err := newSession(ctx, nc, nil, nil, nil, nil, nil, true) + ses, err := newSession(ctx, nc, nil, nil, nil, nil, true) require.NoError(t, err) require.NotNil(t, ses) require.Equal(t, nc, ses.NodeClient()) @@ -61,14 +61,6 @@ func TestNewSession(t *testing.T) { require.Equal(t, os.Stderr, ses.terminal.Stderr()) require.Equal(t, os.Stdout, ses.terminal.Stdout()) require.Equal(t, os.Stdin, ses.terminal.Stdin()) - - // pass environ map - env := map[string]string{ - sshutils.SessionEnvVar: "session-id", - } - ses, err = newSession(ctx, nc, nil, env, nil, nil, nil, true) - require.NoError(t, err) - require.NotNil(t, ses) } // TestProxyConnection verifies that client or server-side disconnect diff --git a/lib/client/session.go b/lib/client/session.go index 69cdbddb0c6ec..65556e9f8a029 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -20,12 +20,15 @@ package client import ( 
"context" + "encoding/json" "errors" "fmt" "io" + "maps" "net" "os" "os/signal" + "strconv" "strings" "sync" "sync/atomic" @@ -101,8 +104,7 @@ type NodeSession struct { // of another user func newSession(ctx context.Context, client *NodeClient, - joinSession types.SessionTracker, - env map[string]string, + sessionParams *tracessh.SessionParams, stdin io.Reader, stdout io.Writer, stderr io.Writer, @@ -116,8 +118,21 @@ func newSession(ctx context.Context, return nil, trace.Wrap(err) } - if env == nil { - env = make(map[string]string) + env := make(map[string]string) + maps.Copy(env, client.TC.ExtraEnvs) + + // TODO(Joerger): DELETE IN v20.0.0 - session params are provided in the session + // request as extra data rather than env vars. + if sessionParams != nil { + env[teleport.SSHSessionWebProxyAddr] = sessionParams.WebProxyAddr + env[teleport.EnvSSHJoinMode] = string(sessionParams.JoinMode) + env[teleport.EnvSSHSessionReason] = sessionParams.Reason + env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(sessionParams.DisplayParticipantRequirements) + encoded, err := json.Marshal(&sessionParams.Invited) + if err != nil { + return nil, trace.Wrap(err) + } + env[teleport.EnvSSHSessionInvited] = string(encoded) } ns := &NodeSession{ @@ -129,11 +144,11 @@ func newSession(ctx context.Context, terminal: term, shouldClearOnExit: client.FIPSEnabled || isFIPS(), } - // if we're joining an existing session, we need to assume that session's - // existing/current terminal size: - if joinSession != nil { - sessionID := joinSession.GetSessionID() - terminalSize, err := client.GetRemoteTerminalSize(ctx, sessionID) + + if sessionParams != nil && sessionParams.JoinSessionID != "" { + // if we're joining an existing session, we need to assume that session's + // existing/current terminal size: + terminalSize, err := client.GetRemoteTerminalSize(ctx, sessionParams.JoinSessionID) if err != nil { return nil, trace.Wrap(err) } @@ -143,10 +158,10 @@ func 
newSession(ctx context.Context, if err != nil { log.ErrorContext(ctx, "Failed to resize terminal", "error", err) } - } - ns.env[sshutils.SessionEnvVar] = sessionID + // TODO(Joerger): DELETE IN v20.0.0 - session env var is no longer used for session joining. + ns.env[sshutils.SessionEnvVar] = sessionParams.JoinSessionID } // Close the Terminal when finished. @@ -171,7 +186,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func(s *tracessh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, sessionParams *tracessh.SessionParams, sessionCallback func(s *tracessh.Session) error) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/regularSession", @@ -179,7 +194,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func( ) defer span.End() - session, err := ns.createServerSession(ctx) + session, err := ns.createServerSession(ctx, nil) if err != nil { return trace.Wrap(err) } @@ -191,7 +206,7 @@ func (ns *NodeSession) regularSession(ctx context.Context, sessionCallback func( type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context, sessionParams *tracessh.SessionParams) (*tracessh.Session, error) { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/createServerSession", @@ -199,7 +214,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi ) defer span.End() - sess, err := ns.nodeClient.Client.NewSession(ctx) + sess, err := ns.nodeClient.Client.NewSessionWithParams(ctx, sessionParams) if err != nil { return nil, trace.Wrap(err) } @@ -267,7 +282,7 @@ func selectKeyAgent(ctx context.Context, tc *TeleportClient) sshagent.ClientGett // interactiveSession creates an 
interactive session on the remote node, executes // the given callback on it, and waits for the session to end -func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.SessionParticipantMode, sessionCallback interactiveCallback) error { +func (ns *NodeSession) interactiveSession(ctx context.Context, sessionParams *tracessh.SessionParams, sessionCallback interactiveCallback) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/interactiveSession", @@ -281,7 +296,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio termType = teleport.SafeTerminalType } // create the server-side session: - sess, err := ns.createServerSession(ctx) + sess, err := ns.createServerSession(ctx, sessionParams) if err != nil { return trace.Wrap(err) } @@ -307,6 +322,11 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio ns.watchSignals(remoteTerm) } + mode := types.SessionPeerMode + if sessionParams != nil && sessionParams.JoinMode != "" { + mode = sessionParams.JoinMode + } + // start piping input into the remote shell and pipe the output from // the remote shell into stdout: ns.pipeInOut(ctx, remoteTerm, mode, sess) @@ -511,8 +531,8 @@ func (s *sessionWriter) Write(p []byte) (int, error) { } // runShell executes user's shell on the remote node under an interactive session -func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error { +func (ns *NodeSession) runShell(ctx context.Context, sessionParams *tracessh.SessionParams, beforeStart func(io.Writer), shellCallback ShellCreatedCallback) error { + return ns.interactiveSession(ctx, sessionParams, func(s *tracessh.Session, shell io.ReadWriteCloser) error { w := &sessionWriter{ tshOut: ns.nodeClient.TC.Stdout, session: s, @@ -540,7 +560,7 @@ func (ns 
*NodeSession) runShell(ctx context.Context, mode types.SessionParticipa // runCommand executes a "exec" request either in interactive mode (with a // TTY attached) or non-intractive mode (no TTY). -func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionParticipantMode, cmd []string, shellCallback ShellCreatedCallback, interactive bool) error { +func (ns *NodeSession) runCommand(ctx context.Context, sessionParams *tracessh.SessionParams, cmd []string, shellCallback ShellCreatedCallback, interactive bool) error { ctx, span := ns.nodeClient.Tracer.Start( ctx, "nodeClient/runCommand", @@ -554,7 +574,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // keyboard based signals will be propogated to the TTY on the server which is // where all signal handling will occur. if interactive { - return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error { + return ns.interactiveSession(ctx, sessionParams, func(s *tracessh.Session, term io.ReadWriteCloser) error { err := s.Start(ctx, strings.Join(cmd, " ")) if err != nil { return trace.Wrap(err) @@ -581,7 +601,7 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // Unfortunately at the moment the Go SSH library Teleport uses does not // support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and // SSH client, and try and exit as gracefully as possible. 
- return ns.regularSession(ctx, func(s *tracessh.Session) error { + return ns.regularSession(ctx, sessionParams, func(s *tracessh.Session) error { errCh := make(chan error, 1) go func() { errCh <- s.Run(ctx, strings.Join(cmd, " ")) diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index ebc4704719121..fe7e1187b5767 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -20,6 +20,7 @@ package srv import ( "context" + "encoding/json" "fmt" "io" "log/slog" @@ -53,6 +54,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/envutils" @@ -314,6 +316,9 @@ type ServerContext struct { // term holds PTY if it was requested by the session. term Terminal + // sessionParams are parameters associated with this server session. + sessionParams *tracessh.SessionParams + // session holds the active session (if there's an active one). session *session @@ -436,7 +441,7 @@ type ServerContext struct { // the ServerContext is closed. The ctx parameter should be a child of the ctx // associated with the scope of the parent ConnectionContext to ensure that // cancellation of the ConnectionContext propagates to the ServerContext. 
-func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) { +func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, srv Server, identityContext IdentityContext, sessionParams *tracessh.SessionParams, monitorOpts ...func(*MonitorConfig)) (*ServerContext, error) { recConfig, err := srv.GetAccessPoint().GetSessionRecordingConfig(ctx) if err != nil { return nil, trace.Wrap(err) @@ -476,6 +481,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s id: int(atomic.AddInt32(&ctxID, int32(1))), env: make(map[string]string), srv: srv, + sessionParams: sessionParams, ExecResultCh: make(chan ExecResult, 10), SubsystemResultCh: make(chan SubsystemResult, 10), ClusterName: parent.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName], @@ -587,31 +593,6 @@ func (c *ServerContext) ID() int { return c.id } -// GetJoinParams gets join params if they are set. -// -// These params (env vars) are set synchronously between the "session" channel request -// and the "shell" / "exec" channel request. Therefore, these params are only guaranteed -// to be accurately set during and after the "shell" / "exec" channel request. -// -// TODO(Joerger): Rather than relying on the out-of-band env var params, we should -// provide session params upfront as extra data in the session channel request. -func (c *ServerContext) GetJoinParams() (string, types.SessionParticipantMode) { - c.mu.RLock() - defer c.mu.RUnlock() - - sid, found := c.getEnvLocked(sshutils.SessionEnvVar) - if !found { - return "", "" - } - - mode := types.SessionPeerMode // default - if modeString, found := c.getEnvLocked(teleport.EnvSSHJoinMode); found { - mode = types.SessionParticipantMode(modeString) - } - - return sid, mode -} - // SessionID returns the ID of the session in the context. 
// // This value is not set until during and after the "shell" / "exec" channel request. @@ -680,19 +661,35 @@ func (c *ServerContext) SetEnv(key, val string) { c.mu.Unlock() } -// GetEnv returns a environment variable within this context. -func (c *ServerContext) GetEnv(key string) (string, bool) { +// GetSessionParams gets session params for the current session. +func (c *ServerContext) GetSessionParams() tracessh.SessionParams { c.mu.RLock() defer c.mu.RUnlock() - return c.getEnvLocked(key) -} -func (c *ServerContext) getEnvLocked(key string) (string, bool) { - val, ok := c.env[key] - if ok { - return val, true + // Teleport ssh clients should provide session params upfront in the session channel request. + if c.sessionParams != nil { + return *c.sessionParams + } + + // If this is an old client, it will provide session params from + // env variables sometime between the session channel request and shell request. + // TODO(Joerger): DELETE IN v20.0.0 - just return empty params for an old Teleport client / openSSH client session. 
+ sessionParams := tracessh.SessionParams{ + WebProxyAddr: c.env[teleport.SSHSessionWebProxyAddr], + Reason: c.env[teleport.EnvSSHSessionReason], + DisplayParticipantRequirements: utils.AsBool(c.env[teleport.EnvSSHSessionDisplayParticipantRequirements]), + JoinSessionID: c.env[sshutils.SessionEnvVar], + JoinMode: types.SessionParticipantMode(c.env[teleport.EnvSSHJoinMode]), + ModeratedSessionID: c.env[sftp.EnvModeratedSessionID], + } + + if invitedUsers := c.env[teleport.EnvSSHSessionInvited]; invitedUsers != "" { + if err := json.Unmarshal([]byte(invitedUsers), &sessionParams.Invited); err != nil { + slog.WarnContext(context.Background(), "Failed to parse invited users", "error", err) + } } - return c.Parent().GetEnv(key) + + return sessionParams } // setSession sets the context's session @@ -1166,11 +1163,15 @@ func buildEnvironment(ctx *ServerContext) []string { } // Set some Teleport specific environment variables: SSH_TELEPORT_USER, - // SSH_TELEPORT_HOST_UUID, and SSH_TELEPORT_CLUSTER_NAME. + // SSH_TELEPORT_HOST_UUID, SSH_TELEPORT_CLUSTER_NAME, and SSH_SESSION_WEBPROXY_ADDR. env.AddTrusted(teleport.SSHTeleportHostUUID, ctx.srv.ID()) env.AddTrusted(teleport.SSHTeleportClusterName, ctx.ClusterName) env.AddTrusted(teleport.SSHTeleportUser, ctx.Identity.TeleportUser) + if ctx.GetSessionParams().WebProxyAddr != "" { + env.AddTrusted(teleport.SSHSessionWebProxyAddr, ctx.GetSessionParams().WebProxyAddr) + } + // At the end gather all dynamically defined environment variables ctx.VisitEnv(env.AddUnique) diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 177c6f6aea062..94d7248427a4e 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -562,6 +562,10 @@ func (s *Server) Serve() { config.KeyExchanges = s.kexAlgorithms config.MACs = s.macAlgorithms + // Set the server version to Teleport to enable tracing and other Teleport + // specific features like joining. 
+ config.ServerVersion = sshutils.SSHVersionPrefix + netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(s.Context()) if err != nil { s.logger.ErrorContext(s.Context(), "Unable to fetch cluster config", "error", err) @@ -912,7 +916,7 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha // Create context for this channel. This context will be closed when // forwarding is complete. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, nil) if err != nil { if err := nch.Reject(ssh.ConnectionFailed, "failed to open server context"); err != nil { s.logger.ErrorContext(ctx, "Error rejecting forwarded-tcpip channel", "error", err) @@ -1022,7 +1026,7 @@ func (s *Server) checkTCPIPForwardRequest(ctx context.Context, r *ssh.Request) e // RBAC checks are only necessary when connecting to an agentless node if s.targetServer.IsOpenSSHNode() { - scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext, nil) if err != nil { return err } @@ -1085,7 +1089,7 @@ func (s *Server) handleChannel(ctx context.Context, nch ssh.NewChannel) { func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, req *sshutils.DirectTCPIPReq) { // Create context for this channel. This context will be closed when // forwarding is complete. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, nil) if err != nil { s.logger.ErrorContext(ctx, "Unable to create connection context", "error", err) s.stderrWrite(ctx, ch, "Unable to create connection context.") @@ -1137,12 +1141,22 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r // the remote host. 
Once the session channel has been established, this function's loop handles // all the "exec", "subsystem" and "shell" requests. func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { + // sessionParams will not be passed by old clients (< v19) or OpenSSH clients. + sessionParams, err := tracessh.ParseSessionParams(nch.ExtraData()) + if err != nil { + s.logger.ErrorContext(ctx, "Failed to parse request data", "data", string(nch.ExtraData()), "error", err) + if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + // Create context for this channel. This context will be closed when the // session request is complete. // There is no need for the forwarding server to initiate disconnects, // based on teleport business logic, because this logic is already // done on the server's terminating side. - scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext) + scx, err := srv.NewServerContext(ctx, s.connectionContext, s, s.identityContext, sessionParams) if err != nil { s.logger.WarnContext(ctx, "Server context setup failed", "error", err) if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("server context setup failed: %v", err)); err != nil { @@ -1163,7 +1177,7 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { // create the remote session channel before accepting the local // channel request; this allows us to propagate the rejection // reason/message in the event the channel is rejected. 
- remoteSession, err := s.remoteClient.NewSession(ctx) + remoteSession, err := s.remoteClient.NewSessionWithParams(ctx, sessionParams) if err != nil { s.logger.WarnContext(ctx, "Remote session open failed", "error", err) reason, msg := ssh.ConnectionFailed, fmt.Sprintf("remote session open failed: %v", err) diff --git a/lib/srv/git/forward.go b/lib/srv/git/forward.go index 0777192906016..5383dba230fcb 100644 --- a/lib/srv/git/forward.go +++ b/lib/srv/git/forward.go @@ -374,7 +374,7 @@ func (s *ForwardServer) onConnection(ctx context.Context, ccx *sshutils.Connecti // TODO(greedy52) decouple from srv.NewServerContext. We only need // connection monitoring. - serverCtx, err := srv.NewServerContext(ctx, ccx, s, identityCtx) + serverCtx, err := srv.NewServerContext(ctx, ccx, s, identityCtx, nil) if err != nil { return nil, trace.Wrap(err) } @@ -400,11 +400,18 @@ func (s *ForwardServer) onChannel(ctx context.Context, ccx *sshutils.ConnectionC return } + // sessionParams will not be passed by old clients (< v19) or OpenSSH clients. + sessionParams, err := tracessh.ParseSessionParams(nch.ExtraData()) + if err != nil { + s.reply.RejectWithAcceptError(ctx, nch, err) + return + } + if s.remoteClient == nil { s.reply.RejectWithNewRemoteSessionError(ctx, nch, trace.NotFound("missing remote client")) return } - remoteSession, err := s.remoteClient.NewSession(ctx) + remoteSession, err := s.remoteClient.NewSessionWithParams(ctx, sessionParams) if err != nil { s.reply.RejectWithNewRemoteSessionError(ctx, nch, err) return diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 8f588dc646bb8..8b102bac8cf2e 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1238,7 +1238,7 @@ func (s *Server) getNetworkingProcess(scx *srv.ServerContext) (*networking.Proce // the server connection is closed. func (s *Server) startNetworkingProcess(scx *srv.ServerContext) (*networking.Process, error) { // Create context for the networking process. 
- nsctx, err := srv.NewServerContext(context.Background(), scx.ConnectionContext, s, scx.Identity) + nsctx, err := srv.NewServerContext(context.Background(), scx.ConnectionContext, s, scx.Identity, nil) if err != nil { return nil, trace.Wrap(err) } @@ -1423,7 +1423,7 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont s.rejectChannel(ctx, nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } - go s.handleSessionRequests(ctx, ccx, identityContext, ch, requests) + go s.handleSessionRequests(ctx, ccx, identityContext, nil, ch, requests) return default: s.rejectChannel(ctx, nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) @@ -1466,6 +1466,15 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont } decr = d } + + // SessionParams are not passed by old clients ( Date: Wed, 12 Nov 2025 17:13:25 -0800 Subject: [PATCH 4/4] Coordinate session ID with Node in Proxy recording mode (#59850) * Generalize PrepareToReceiveSessionID. * Initialize session ID in the connection context and update it from node current-session-id request. * Add session-id-query-v2@goteleport.com request and ensure new session ID is correctly set in proxy recording mode during the channel request. * Replace PrepareToReceiveSessionID with simpler in-place logic. * Don't emit session events or tracker when proxy forwarding to a Teleport Node. * Fix missing session tracker for outdated Teleport Node. * Remove extra major version grace period. * Update integration test. * Cleanup current session ID handling and fix failing tests. * Fix tests. * Address comments. * Restructure currentSessionID handling. * Set newSessionID in test server context. * Fix integration test. * Fix AuditOn integration test. * Address comment on channel close. * Track session on forwarding node. * Fix web shutdown. * Fix nil pointer dereference in test. * Fix test flake. * Fix nil pointer in test. 
* Fix test flake. * Update lib/srv/ctx.go Co-authored-by: rosstimothy <39066650+rosstimothy@users.noreply.github.com> * Forwarding Node accepts client connection after receiving preparing session ID from node. This way, the forwarder can reject client connections if there is an issue preparing the session ID (impossible join sessions). * Remove check for session.data event which may not be emitted in time for the test. * Address comments. --------- Co-authored-by: rosstimothy <39066650+rosstimothy@users.noreply.github.com> --- constants.go | 15 +++- integration/integration_test.go | 42 ++++++---- lib/srv/ctx.go | 42 +++++++--- lib/srv/exec.go | 8 +- lib/srv/exec_test.go | 1 + lib/srv/forward/sshserver.go | 104 +++++++++++++++++++++++- lib/srv/mock_test.go | 6 ++ lib/srv/regular/sshserver.go | 39 ++++++++- lib/srv/regular/sshserver_test.go | 17 ++-- lib/srv/sess.go | 45 ++++++----- lib/srv/sess_test.go | 129 ++++++++++++------------------ lib/srv/term.go | 4 +- lib/web/sessions.go | 81 ------------------- lib/web/terminal.go | 34 +++++--- 14 files changed, 332 insertions(+), 235 deletions(-) diff --git a/constants.go b/constants.go index 47b4defc1846f..c52e95ad3df36 100644 --- a/constants.go +++ b/constants.go @@ -831,9 +831,22 @@ const ( CurrentSessionIDRequest = "current-session-id@goteleport.com" // SessionIDQueryRequest is sent by clients to ask servers if they - // will generate their own session ID when a new session is created. + // will generate and share their own session ID when a new session + // is started (session and exec/shell channels accepted). + // + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. SessionIDQueryRequest = "session-id-query@goteleport.com" + // SessionIDQueryRequestV2 is sent by clients to ask servers if they + // will generate and share their own session ID when a new session + // channel is accepted, rather than when the shell/exec channel is. 
+ // + // TODO(Joerger): DELETE IN v21.0.0 + // all v19+ servers set the session ID directly after accepting the session channel. + // clients should stop checking in v21, and servers should stop responding to the query in v22. + SessionIDQueryRequestV2 = "session-id-query-v2@goteleport.com" + // ForceTerminateRequest is an SSH request to forcefully terminate a session. ForceTerminateRequest = "x-teleport-force-terminate" diff --git a/integration/integration_test.go b/integration/integration_test.go index 33772ccecdcd8..76e8523ecc8ba 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -522,6 +522,7 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } } + // Test streaming events and recording. capturedStream, sessionEvents := streamSession(ctx, t, site, sessionID) findByType := func(et string) apievents.AuditEvent { @@ -532,19 +533,6 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { } return nil } - // helper that asserts that a session event is also included in the - // general audit log. 
- requireInAuditLog := func(t *testing.T, sessionEvent apievents.AuditEvent) { - t.Helper() - auditEvents, _, err := site.SearchEvents(ctx, events.SearchEventsRequest{ - To: time.Now(), - EventTypes: []string{sessionEvent.GetType()}, - }) - require.NoError(t, err) - require.True(t, slices.ContainsFunc(auditEvents, func(ae apievents.AuditEvent) bool { - return ae.GetID() == sessionEvent.GetID() - })) - } // there should always be 'session.start' event (and it must be first) first := sessionEvents[0].(*apievents.SessionStart) @@ -552,19 +540,16 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { require.Equal(t, first, start) require.Equal(t, sessionID, start.SessionID) require.NotEmpty(t, start.TerminalSize) - requireInAuditLog(t, start) // there should always be 'session.end' event end := findByType(events.SessionEndEvent).(*apievents.SessionEnd) require.NotNil(t, end) require.Equal(t, sessionID, end.SessionID) - requireInAuditLog(t, end) // there should always be 'session.leave' event leave := findByType(events.SessionLeaveEvent).(*apievents.SessionLeave) require.NotNil(t, leave) require.Equal(t, sessionID, leave.SessionID) - requireInAuditLog(t, leave) // all of them should have a proper time for _, e := range sessionEvents { @@ -575,6 +560,31 @@ func testAuditOn(t *testing.T, suite *integrationTestSuite) { recorded := replaceNewlines(capturedStream) require.Regexp(t, ".*exit.*", recorded) require.Regexp(t, ".*echo hi.*", recorded) + + sessionEvents, _, err = site.SearchEvents(ctx, events.SearchEventsRequest{ + From: time.Time{}, + To: time.Now(), + EventTypes: []string{ + events.SessionStartEvent, + events.SessionLeaveEvent, + events.SessionEndEvent, + }, + }) + require.NoError(t, err) + + // Check that the events found above in the session stream show up in the backend. 
+ require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == start.GetID() + }), "expected session events to contain session.start event") + require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == end.GetID() + }), "expected session events to contain session.end event") + require.True(t, slices.ContainsFunc(sessionEvents, func(ae apievents.AuditEvent) bool { + return ae.GetID() == leave.GetID() + }), "expected session events to contain session.leave event") + + // Ensure there are no duplicate events, e.g. from proxy recording mode. + require.Len(t, sessionEvents, 3, "%d unexpected duplicate events", len(sessionEvents)-4) }) } } diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index fe7e1187b5767..d43346e097a88 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -52,6 +52,7 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/sshutils/sftp" @@ -319,6 +320,12 @@ type ServerContext struct { // sessionParams are parameters associated with this server session. sessionParams *tracessh.SessionParams + // newSessionID is set if this server context is going to create a new session. + // This field must be set through [ServerContext.SetNewSessionID] for non-join + // sessions as soon as a session channel is accepted in order to inform + // the client of the to-be session ID. + newSessionID rsession.ID + // session holds the active session (if there's an active one). session *session @@ -692,21 +699,25 @@ func (c *ServerContext) GetSessionParams() tracessh.SessionParams { return sessionParams } +// SetNewSessionID sets the ID for a new session in this server context. 
+func (c *ServerContext) SetNewSessionID(ctx context.Context, sid rsession.ID) { + c.mu.Lock() + defer c.mu.Unlock() + c.newSessionID = sid +} + +// GetNewSessionID gets the ID for a new session in this server context. +func (c *ServerContext) GetNewSessionID() rsession.ID { + c.mu.Lock() + defer c.mu.Unlock() + return c.newSessionID +} + // setSession sets the context's session -func (c *ServerContext) setSession(ctx context.Context, sess *session, ch ssh.Channel) { +func (c *ServerContext) setSession(ctx context.Context, sess *session) { c.mu.Lock() defer c.mu.Unlock() c.session = sess - - // inform the client of the session ID that is being used in a new - // goroutine to reduce latency - go func() { - c.Logger.DebugContext(ctx, "Sending current session ID") - _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sess.ID())) - if err != nil { - c.Logger.DebugContext(ctx, "Failed to send the current session ID", "error", err) - } - }() } // getSession returns the context's session @@ -922,6 +933,15 @@ func (c *ServerContext) reportStats(conn utils.Stater) { serverRX.Add(float64(rxBytes)) } +// ShouldHandleSessionRecording returns whether this server context is responsible for +// recording session events, including session recording, audit events, and session tracking. +func (c *ServerContext) ShouldHandleSessionRecording() bool { + // The only time this server is not responsible for recording the session is when this + // is a Teleport Node with Proxy recording mode turned on, where the forwarding node will + // handle the recording. + return c.srv.Component() != teleport.ComponentNode || !services.IsRecordAtProxy(c.SessionRecordingConfig.GetMode()) +} + func (c *ServerContext) Close() error { // If the underlying connection is holding tracking information, report that // to the audit log at close.
diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 8daf6c7c1c31f..2fab90d5a1a1c 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -39,10 +39,8 @@ import ( "github.com/gravitational/teleport" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" - "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -102,10 +100,8 @@ func NewExecRequest(ctx *ServerContext, command string) (Exec, error) { }, nil } - // If this is a registered OpenSSH node or proxy recoding mode is - // enabled, execute the command on a remote host. This is used by - // in-memory forwarding nodes. - if types.IsOpenSSHNodeSubKind(ctx.ServerSubKind) || services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) { + // If this is a forwarding node, execute the command on a remote host. + if ctx.srv.Component() == teleport.ComponentForwardingNode { return &remoteExec{ ctx: ctx, command: command, diff --git a/lib/srv/exec_test.go b/lib/srv/exec_test.go index 086a43e7985a2..d5b38453e38ef 100644 --- a/lib/srv/exec_test.go +++ b/lib/srv/exec_test.go @@ -149,6 +149,7 @@ func newExecServerContext(t *testing.T, srv Server) *ServerContext { term: term, emitter: rec, recorder: rec, + scx: scx, } err = scx.SetSSHRequest(&ssh.Request{Type: sshutils.ExecRequest}) require.NoError(t, err) diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 94d7248427a4e..1c0f35fa6c849 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -28,6 +28,7 @@ import ( "net" "os" "strings" + "sync" "time" "github.com/gravitational/trace" @@ -52,6 +53,7 @@ import ( "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" +
"github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshagent" "github.com/gravitational/teleport/lib/sshutils" @@ -987,8 +989,23 @@ func (s *Server) handleGlobalRequest(ctx context.Context, req *ssh.Request) { } // Pass request on unchanged. case teleport.SessionIDQueryRequest: + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. + // Reply true to session ID query requests, we will set new - // session IDs for new sessions + // session IDs for new sessions during the shell/exec channel + // request. + if err := req.Reply(true, nil); err != nil { + s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) + } + return + case teleport.SessionIDQueryRequestV2: + // TODO(Joerger): DELETE IN v21.0.0 + // clients should stop checking in v21, and servers should stop responding to the query in v22. + + // Reply true to session ID query requests, we will set new + // session IDs for new sessions directly after accepting the + // session channel request. + if err := req.Reply(true, nil); err != nil { + s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) + } @@ -1173,6 +1190,47 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { scx.SetAllowFileCopying(true) defer scx.Close() + // If this is a Teleport node server, it should send the session ID + // right after the session channel is accepted. We should reuse this + // session ID and delegate session responsibilities (recordings, audit + // events, and session trackers) to avoid duplicates. + // + // Register handler to receive the current session ID before starting the session. + var newSessionIDFromServer chan string + if s.targetServer.GetSubKind() == types.SubKindTeleportNode { + // Check if the Teleport Node is outdated and won't actually send the session ID.
+ // + // TODO(Joerger): DELETE IN v21.0.0 + // all v19+ servers set and share the session ID directly after accepting the session channel. + // clients should stop checking in v21, and servers should stop responding to the query in v22. + reply, payload, err := s.remoteClient.SendRequest(ctx, teleport.SessionIDQueryRequestV2, true, nil) + if err != nil { + s.logger.WarnContext(ctx, "Failed to send session ID query request", "error", err) + } else if !reply && payload != nil { + // If the target node replies with a payload, this means that the connection itself has been rejected, + // presumably due to an authz error, and the server is trying to communicate the error with the first + // req/chan received. + s.logger.WarnContext(ctx, "Remote session open failed", "error", err) + if err := nch.Reject(ssh.Prohibited, fmt.Sprintf("remote session open failed: %v", string(payload))); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + + if err == nil && reply { + newSessionIDFromServer = make(chan string, 1) + var receiveSessionIDOnce sync.Once + s.remoteClient.HandleSessionRequest(ctx, teleport.CurrentSessionIDRequest, func(ctx context.Context, req *ssh.Request) { + // Only handle the first request - only one is expected. + receiveSessionIDOnce.Do(func() { + newSessionIDFromServer <- string(req.Payload) + }) + }) + } else { + s.logger.WarnContext(ctx, "Failed to query session ID from target node. Ensure the targeted Teleport Node is upgraded to v19.0.0+ to avoid duplicate events due to mismatched session IDs.") + } + } + // Create a "session" channel on the remote host.
Note that we // create the remote session channel before accepting the local // channel request; this allows us to propagate the rejection @@ -1192,6 +1250,38 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { } scx.RemoteSession = remoteSession + if newSessionIDFromServer != nil { + // Wait for the session ID to be reported by the target node. + select { + case sidString := <-newSessionIDFromServer: + sid, err := session.ParseID(sidString) + if err != nil { + s.logger.WarnContext(ctx, "Unable to parse session ID reported by target Teleport Node", "error", err) + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + scx.SetNewSessionID(ctx, *sid) + case <-time.After(10 * time.Second): + s.logger.WarnContext(ctx, "Failed to receive session ID from target node. Ensure the targeted Teleport Node is upgraded to v19.0.0+ to avoid duplicate events due to mismatched session IDs.") + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + case <-ctx.Done(): + if err := nch.Reject(ssh.ConnectionFailed, "target Teleport Node failed to report session ID"); err != nil { + s.logger.WarnContext(ctx, "Failed to reject channel", "channel", nch.ChannelType(), "error", err) + } + return + } + } else { + // The target node is not expected to report session ID, either because it's + // outdated or an agentless node. Continue with a random session ID and ensure + // we create a new session tracker. 
+ scx.SetNewSessionID(ctx, session.NewID()) + } + // Accept the session channel request ch, in, err := nch.Accept() if err != nil { @@ -1202,9 +1292,19 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { return } scx.AddCloser(ch) - ch = scx.TrackActivity(ch) + // inform the client of the session ID that is going to be used in a new + // goroutine to reduce latency. + go func() { + sid := scx.GetNewSessionID() + s.logger.DebugContext(ctx, "Sending current session ID", "sid", sid) + _, err := ch.SendRequest(teleport.CurrentSessionIDRequest, false, []byte(sid)) + if err != nil { + s.logger.DebugContext(ctx, "Failed to send the current session ID", "error", err) + } + }() + s.logger.DebugContext(ctx, "Opening session request", "target_addr", s.sconn.RemoteAddr(), "session_id", scx.ID()) defer s.logger.DebugContext(ctx, "Closing session request", "target_addr", s.sconn.RemoteAddr(), "session_id", scx.ID()) diff --git a/lib/srv/mock_test.go b/lib/srv/mock_test.go index 78f1980f727ea..a4ea1fb6a98c5 100644 --- a/lib/srv/mock_test.go +++ b/lib/srv/mock_test.go @@ -47,6 +47,7 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -74,6 +75,7 @@ func newTestServerContext(t *testing.T, srv Server, sessionJoiningRoleSet servic clusterName := "localhost" _, connCtx := sshutils.NewConnectionContext(ctx, nil, &ssh.ServerConn{Conn: sshConn}) scx := &ServerContext{ + newSessionID: rsession.NewID(), Logger: logtest.NewLogger(), ConnectionContext: connCtx, env: make(map[string]string), @@ -93,6 +95,9 @@ func newTestServerContext(t *testing.T, srv Server, sessionJoiningRoleSet servic }, cancelContext: ctx, cancel: cancel, + // If proxy 
forwarding is being used (proxy recording, agentless), then remote session must be set. + // Otherwise, this field is ignored. + RemoteSession: mockSSHSession(t), } err = scx.SetExecRequest(&localExec{Ctx: scx}) @@ -161,6 +166,7 @@ func newMockServer(t *testing.T) *mockServer { datadir: t.TempDir(), MockRecorderEmitter: &eventstest.MockRecorderEmitter{}, clock: clock, + component: teleport.ComponentNode, } } diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 8b102bac8cf2e..b4a6c105d4c69 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -67,6 +67,7 @@ import ( authorizedkeysreporter "github.com/gravitational/teleport/lib/secretsscanner/authorizedkeys" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/sshagent" @@ -1297,8 +1298,23 @@ func (s *Server) HandleRequest(ctx context.Context, ccx *sshutils.ConnectionCont } } case teleport.SessionIDQueryRequest: + // TODO(Joerger): DELETE IN v20.0.0 + // All v17+ servers set the session ID. v19+ clients stop checking. + + // Reply true to session ID query requests, we will set new + // session IDs for new sessions during the shell/exec channel + // request. + if err := r.Reply(true, nil); err != nil { + s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) + } + return + case teleport.SessionIDQueryRequestV2: + // TODO(Joerger): DELETE IN v21.0.0 + // clients should stop checking in v21, and servers should stop responding to the query in v22. + // Reply true to session ID query requests, we will set new - // session IDs for new sessions + // session IDs for new sessions directly after accepting the + // session channel request.
if err := r.Reply(true, nil); err != nil { s.logger.WarnContext(ctx, "Failed to reply to session ID query request", "error", err) } @@ -1647,6 +1663,27 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec trackingChan := scx.TrackActivity(ch) + // If we are creating a new session (not joining a session), prepare a new session + // ID and inform the client. + // + // Note: If this is an old client (