diff --git a/integration/integration_test.go b/integration/integration_test.go index d9a44d58c228e..2dbad98d503c6 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -159,6 +159,7 @@ func TestIntegrations(t *testing.T) { t.Run("IP Propagation", suite.bind(testIPPropagation)) t.Run("JumpTrustedClusters", suite.bind(testJumpTrustedClusters)) t.Run("JumpTrustedClustersWithLabels", suite.bind(testJumpTrustedClustersWithLabels)) + t.Run("LeafSessionRecording", suite.bind(testLeafProxySessionRecording)) t.Run("List", suite.bind(testList)) t.Run("MapRoles", suite.bind(testMapRoles)) t.Run("ModeratedSessions", suite.bind(testModeratedSessions)) @@ -1138,6 +1139,152 @@ func testSessionRecordingModes(t *testing.T, suite *integrationTestSuite) { } } +func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) { + tests := []struct { + rootRecordingMode string + leafRecordingMode string + rootHasSess bool + }{ + { + rootRecordingMode: types.RecordAtNode, + leafRecordingMode: types.RecordAtProxy, + rootHasSess: true, + }, + { + rootRecordingMode: types.RecordAtProxy, + leafRecordingMode: types.RecordAtNode, + rootHasSess: true, + }, + { + rootRecordingMode: types.RecordAtNode, + leafRecordingMode: types.RecordAtNode, + rootHasSess: false, + }, + { + rootRecordingMode: types.RecordAtProxy, + leafRecordingMode: types.RecordAtProxy, + rootHasSess: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("root rec mode=%q leaf rec mode=%q", + tt.rootRecordingMode, + tt.leafRecordingMode, + ), func(t *testing.T) { + // Create and start clusters + _, root, leaf := createTrustedClusterPair(t, suite, nil, func(cfg *servicecfg.Config, isRoot bool) { + auditConfig, err := types.NewClusterAuditConfig(types.ClusterAuditConfigSpecV2{ + AuditSessionsURI: t.TempDir(), + }) + require.NoError(t, err) + + recMode := tt.leafRecordingMode + if isRoot { + recMode = tt.rootRecordingMode + } + recCfg, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ + Mode: recMode, + }) + require.NoError(t, err) + + cfg.Auth.Enabled = true + cfg.Auth.AuditConfig = auditConfig + cfg.Auth.SessionRecordingConfig = recCfg + cfg.Proxy.Enabled = true + cfg.SSH.Enabled = true + }) + + authSrv := root.Process.GetAuthServer() + uploadChan := root.UploadEventsC + if !tt.rootHasSess { + authSrv = leaf.Process.GetAuthServer() + uploadChan = leaf.UploadEventsC + } + + tc, err := root.NewClient(helpers.ClientConfig{ + Login: suite.Me.Username, + Cluster: "leaf-test", + Host: "leaf-zero:0", + }) + require.NoError(t, err) + + ctx := context.Background() + clt, err := tc.ConnectToCluster(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, clt.Close()) + }) + + // Create an interactive SSH session to start session recording + term := NewTerminal(250) + errCh := make(chan error) + + tc.Stdout = term + tc.Stdin = term + + go func() { + nodeClient, err := tc.ConnectToNode( + ctx, + clt, + client.NodeDetails{Addr: "leaf-zero:0", Namespace: tc.Namespace, Cluster: clt.ClusterName()}, + tc.Config.HostLogin, + ) + assert.NoError(t, err) + + errCh <- nodeClient.RunInteractiveShell(ctx, types.SessionPeerMode, nil, nil) + assert.NoError(t, nodeClient.Close()) + }() + + var sessionID string + require.EventuallyWithT(t, func(c *assert.CollectT) { + trackers, err := authSrv.GetActiveSessionTrackers(ctx) + if !assert.NoError(c, err) { + return + } + if !assert.Len(c, trackers, 1) { + return + } + sessionID = trackers[0].GetSessionID() + }, time.Second*5, time.Millisecond*100) + + // Send stuff to the session. + term.Type("echo Hello\n\r") + + // Guarantee the session hasn't stopped after typing. + select { + case <-errCh: + require.Fail(t, "session was closed before") + default: + } + + // Wait for the session to terminate without error. + term.Type("exit\n\r") + require.NoError(t, waitForError(errCh, 5*time.Second)) + + // Wait for the session recording to be uploaded and available + var uploaded bool + timeoutC := time.After(10 * time.Second) + for !uploaded { + select { + case event := <-uploadChan: + if event.SessionID == sessionID { + uploaded = true + } + case <-timeoutC: + require.Fail(t, "timeout waiting for session recording to be uploaded") + } + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + events, err := authSrv.GetSessionEvents(defaults.Namespace, session.ID(sessionID), 0) + assert.NoError(t, err) + assert.NotEmpty(t, events) + }, 5*time.Second, 200*time.Millisecond) + }) + } +} + // TestCustomReverseTunnel tests that the SSH node falls back to configured // proxy address if it cannot connect via the proxy address from the reverse // tunnel discovery query. @@ -7379,9 +7526,9 @@ outer: t.FailNow() } -type serviceCfgOpt func(*servicecfg.Config) +type serviceCfgOpt func(cfg *servicecfg.Config, isRoot bool) -func withProxyRecordingMode(cfg *servicecfg.Config) { +func withProxyRecordingMode(cfg *servicecfg.Config, _ bool) { recCfg := types.DefaultSessionRecordingConfig() recCfg.SetMode(types.RecordAtProxy) cfg.Auth.SessionRecordingConfig = recCfg @@ -7426,13 +7573,22 @@ func createTrustedClusterPair(t *testing.T, suite *integrationTestSuite, extraSe AppLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, KubernetesLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, DatabaseLabels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + Rules: []types.Rule{ + { + Resources: []string{types.KindSession}, + Verbs: []string{ + types.VerbList, + types.VerbRead, + }, + }, + }, }, }) require.NoError(t, err) root.AddUserWithRole(username, role) leaf.AddUserWithRole(username, role) - makeConfig := func() (*testing.T, []*helpers.InstanceSecrets, *servicecfg.Config) { + makeConfig := func(isRoot bool) (*testing.T, []*helpers.InstanceSecrets, *servicecfg.Config) { tconf := suite.defaultServiceConfig() tconf.Proxy.DisableWebService = false tconf.Proxy.DisableWebInterface = true @@ -7440,7 +7596,7 @@ func createTrustedClusterPair(t *testing.T, suite *integrationTestSuite, extraSe tconf.CachePolicy.MaxRetryPeriod = time.Millisecond * 500 for _, opt := range cfgOpts { - opt(tconf) + opt(tconf, isRoot) } return t, nil, tconf @@ -7450,8 +7606,8 @@ func createTrustedClusterPair(t *testing.T, suite *integrationTestSuite, extraSe lib.SetInsecureDevMode(true) defer lib.SetInsecureDevMode(oldInsecure) - require.NoError(t, root.CreateEx(makeConfig())) - require.NoError(t, leaf.CreateEx(makeConfig())) + require.NoError(t, root.CreateEx(makeConfig(true))) + require.NoError(t, leaf.CreateEx(makeConfig(false))) require.NoError(t, leaf.Process.GetAuthServer().UpsertRole(ctx, role)) // Connect leaf to root. diff --git a/lib/reversetunnel/cache.go b/lib/reversetunnel/cache.go index 6e70b503d8e10..778e7805b15ca 100644 --- a/lib/reversetunnel/cache.go +++ b/lib/reversetunnel/cache.go @@ -31,7 +31,6 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/sshca" ) type certificateCache struct { @@ -39,12 +38,11 @@ type certificateCache struct { cache *ttlmap.TTLMap authClient auth.ClientI - keygen sshca.Authority } // newHostCertificateCache creates a shared host certificate cache that is // used by the forwarding server. -func newHostCertificateCache(keygen sshca.Authority, authClient auth.ClientI) (*certificateCache, error) { +func newHostCertificateCache(authClient auth.ClientI) (*certificateCache, error) { native.PrecomputeKeys() // ensure native package is set to precompute keys cache, err := ttlmap.New(defaults.HostCertCacheSize) if err != nil { @@ -52,7 +50,6 @@ func newHostCertificateCache(keygen sshca.Authority, authClient auth.ClientI) (* } return &certificateCache{ - keygen: keygen, cache: cache, authClient: authClient, }, nil diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 84ae07cc0c51b..6f9e9d251997f 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -119,7 +119,7 @@ func newLocalSite(srv *server, domainName string, authServers []string, opts ... // certificate cache is created in each site (instead of creating it in // reversetunnel.server and passing it along) so that the host certificate // is signed by the correct certificate authority. - certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient) + certificateCache, err := newHostCertificateCache(srv.localAuthClient) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 1c6ed62973383..88b2aed0a7c5e 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -764,21 +764,32 @@ func (s *remoteSite) DialAuthServer(params reversetunnelclient.DialParams) (net. // located in a remote connected site, the connection goes through the // reverse proxy tunnel. func (s *remoteSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) { - recConfig, err := s.localAccessPoint.GetSessionRecordingConfig(s.ctx) + localRecCfg, err := s.localAccessPoint.GetSessionRecordingConfig(s.ctx) if err != nil { return nil, trace.Wrap(err) } - if err := checkNodeAndRecConfig(params, recConfig); err != nil { + if err := checkNodeAndRecConfig(params, localRecCfg); err != nil { return nil, trace.Wrap(err) } - // If the proxy is in recording mode and a SSH connection is being - // requested or the target server is a registered OpenSSH node, build - // an in-memory forwarding server. - if shouldDialAndForward(params, recConfig) { + if shouldDialAndForward(params, localRecCfg) { return s.dialAndForward(params) } + if params.ConnType == types.NodeTunnel { + // If the remote cluster is recording at the proxy we need to respect + // that and forward and record the session. We will be connecting + // to the node without connecting through the remote proxy, so the + // session won't have a chance to get recorded at the remote proxy. + remoteRecCfg, err := s.remoteAccessPoint.GetSessionRecordingConfig(s.ctx) + if err != nil { + return nil, trace.Wrap(err) + } + if services.IsRecordAtProxy(remoteRecCfg.GetMode()) { + return s.dialAndForward(params) + } + } + // Attempt to perform a direct TCP dial. return s.DialTCP(params) } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 5875440cd322a..d9a773b878724 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -1167,7 +1167,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, // certificate cache is created in each site (instead of creating it in // reversetunnel.server and passing it along) so that the host certificate // is signed by the correct certificate authority. - certificateCache, err := newHostCertificateCache(srv.Config.KeyGen, srv.localAuthClient) + certificateCache, err := newHostCertificateCache(srv.localAuthClient) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 2976ef20544a1..79d4401a351fa 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -1377,6 +1377,7 @@ func newRecorder(s *session, ctx *ServerContext) (events.SessionPreparerRecorder // Nodes discard events in cases when proxies are already recording them. if s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) { + s.log.WithField("session_id", s.ID()).Trace("session will be recorded at proxy") return events.WithNoOpPreparer(events.NewDiscardRecorder()), nil }