diff --git a/integration/teleterm_test.go b/integration/teleterm_test.go index a2b71b17cc957..1dce72af702e8 100644 --- a/integration/teleterm_test.go +++ b/integration/teleterm_test.go @@ -34,6 +34,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -116,6 +117,12 @@ func TestTeleterm(t *testing.T) { t.Parallel() testDeleteConnectMyComputerNode(t, pack) }) + + t.Run("TestClientCache", func(t *testing.T) { + t.Parallel() + + testClientCache(t, pack, creds) + }) } func testAddingRootCluster(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) { @@ -335,6 +342,79 @@ func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *help ) } +func testClientCache(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) { + ctx := context.Background() + + tc := mustLogin(t, pack.Root.User.GetName(), pack, creds) + + storageFakeClock := clockwork.NewFakeClockAt(time.Now()) + + storage, err := clusters.NewStorage(clusters.Config{ + Dir: tc.KeysDir, + Clock: storageFakeClock, + InsecureSkipVerify: tc.InsecureSkipVerify, + }) + require.NoError(t, err) + + cluster, _, err := storage.Add(ctx, tc.WebProxyAddr) + require.NoError(t, err) + + daemonService, err := daemon.New(daemon.Config{ + Storage: storage, + CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { + daemonService.Stop() + }) + + // Check if parallel calls trying to get a client will return the same one. + eg, egCtx := errgroup.WithContext(ctx) + blocker := make(chan struct{}) + const concurrentCalls = 5 + concurrentCallsForClient := make([]*client.ProxyClient, concurrentCalls) + for i := range concurrentCallsForClient { + client := &concurrentCallsForClient[i] + eg.Go(func() error { + <-blocker + c, err := daemonService.GetCachedClient(egCtx, cluster.URI) + *client = c + return err + }) + } + // unblock the operation which is still in progress + close(blocker) + require.NoError(t, eg.Wait()) + require.Subset(t, concurrentCallsForClient[:1], concurrentCallsForClient[1:]) + + // Since we have a client in the cache, it should be returned. + secondCallForClient, err := daemonService.GetCachedClient(ctx, cluster.URI) + require.NoError(t, err) + require.Equal(t, concurrentCallsForClient[0], secondCallForClient) + + // Let's remove the client from the cache. + // The call to GetCachedClient will + // connect to proxy and return a new client. + err = daemonService.ClearCachedClientsForRoot(cluster.URI) + require.NoError(t, err) + thirdCallForClient, err := daemonService.GetCachedClient(ctx, cluster.URI) + require.NoError(t, err) + require.NotEqual(t, secondCallForClient, thirdCallForClient) + + // After closing the client (from our or a remote side) + // it will be removed from the cache. + // The call to GetCachedClient will connect to proxy and return a new client. + err = thirdCallForClient.Close() + require.NoError(t, err) + fourthCallForClient, err := daemonService.GetCachedClient(ctx, cluster.URI) + require.NoError(t, err) + require.NotEqual(t, thirdCallForClient, fourthCallForClient) +} + func testCreateConnectMyComputerRole(t *testing.T, pack *dbhelpers.DatabasePack) { systemUser, err := user.Current() require.NoError(t, err) diff --git a/lib/teleterm/apiserver/handler/handler_apps.go b/lib/teleterm/apiserver/handler/handler_apps.go index 6f9f7e2ec6d2b..2789aba8d8899 100644 --- a/lib/teleterm/apiserver/handler/handler_apps.go +++ b/lib/teleterm/apiserver/handler/handler_apps.go @@ -34,7 +34,12 @@ func (s *Handler) GetApps(ctx context.Context, req *api.GetAppsRequest) (*api.Ge return nil, trace.Wrap(err) } - resp, err := cluster.GetApps(ctx, req) + proxyClient, err := s.DaemonService.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + resp, err := cluster.GetApps(ctx, proxyClient.CurrentCluster(), req) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/teleterm/apiserver/handler/handler_auth.go b/lib/teleterm/apiserver/handler/handler_auth.go index 2493019774d5d..5ef1283f1e3e6 100644 --- a/lib/teleterm/apiserver/handler/handler_auth.go +++ b/lib/teleterm/apiserver/handler/handler_auth.go @@ -37,6 +37,10 @@ func (s *Handler) Login(ctx context.Context, req *api.LoginRequest) (*api.EmptyR // added by daemon.Service.ResolveClusterURI. clusterClient.MFAPromptConstructor = nil + if err = s.DaemonService.ClearCachedClientsForRoot(cluster.URI); err != nil { + return nil, trace.Wrap(err) + } + if req.Params == nil { return nil, trace.BadParameter("missing login parameters") } @@ -84,6 +88,10 @@ func (s *Handler) LoginPasswordless(stream api.TerminalService_LoginPasswordless // daemon.Service.ResolveClusterURI. clusterClient.MFAPromptConstructor = nil + if err := s.DaemonService.ClearCachedClientsForRoot(cluster.URI); err != nil { + return trace.Wrap(err) + } + // Start the prompt flow. if err := cluster.PasswordlessLogin(stream.Context(), stream); err != nil { return trace.Wrap(err) diff --git a/lib/teleterm/apiserver/handler/handler_databases.go b/lib/teleterm/apiserver/handler/handler_databases.go index cc48d585f72cf..681181b17395b 100644 --- a/lib/teleterm/apiserver/handler/handler_databases.go +++ b/lib/teleterm/apiserver/handler/handler_databases.go @@ -28,14 +28,19 @@ import ( "github.com/gravitational/teleport/lib/teleterm/clusters" ) -// GetDatabases gets databses with filters and returns paginated results +// GetDatabases gets databases with filters and returns paginated results func (s *Handler) GetDatabases(ctx context.Context, req *api.GetDatabasesRequest) (*api.GetDatabasesResponse, error) { cluster, _, err := s.DaemonService.ResolveCluster(req.ClusterUri) if err != nil { return nil, trace.Wrap(err) } - resp, err := cluster.GetDatabases(ctx, req) + proxyClient, err := s.DaemonService.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + resp, err := cluster.GetDatabases(ctx, proxyClient.CurrentCluster(), req) if err != nil { return nil, trace.Wrap(err) } @@ -63,7 +68,12 @@ func (s *Handler) ListDatabaseUsers(ctx context.Context, req *api.ListDatabaseUs return nil, trace.Wrap(err) } - dbUsers, err := cluster.GetAllowedDatabaseUsers(ctx, req.DbUri) + proxyClient, err := s.DaemonService.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + dbUsers, err := cluster.GetAllowedDatabaseUsers(ctx, proxyClient.CurrentCluster(), req.DbUri) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/teleterm/apiserver/middleware.go b/lib/teleterm/apiserver/middleware.go index d9123fa14f9de..ac9004abfa1aa 100644 --- a/lib/teleterm/apiserver/middleware.go +++ b/lib/teleterm/apiserver/middleware.go @@ -20,11 +20,14 @@ package apiserver import ( "context" + "errors" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" "github.com/sirupsen/logrus" "google.golang.org/grpc" + + "github.com/gravitational/teleport/api/client" ) // withErrorHandling is gRPC middleware that maps internal errors to proper gRPC error codes @@ -38,6 +41,14 @@ func withErrorHandling(log logrus.FieldLogger) grpc.UnaryServerInterceptor { resp, err := handler(ctx, req) if err != nil { log.WithError(err).Error("Request failed.") + // A stop gap solution that allows us to show a relogin modal when we + // receive an error from the server saying that the cert is expired. + // Read more: https://github.com/gravitational/teleport/pull/38202#discussion_r1497181659 + // TODO(gzdunek): fix when addressing https://github.com/gravitational/teleport/issues/32550 + if errors.Is(err, client.ErrClientCredentialsHaveExpired) { + return resp, trail.ToGRPC(err) + } + // do not return a full error stack on access denied errors if trace.IsAccessDenied(err) { return resp, trail.ToGRPC(trace.AccessDenied("access denied")) diff --git a/lib/teleterm/clusters/cluster.go b/lib/teleterm/clusters/cluster.go index 7141782cbfbb6..b2e53606a9375 100644 --- a/lib/teleterm/clusters/cluster.go +++ b/lib/teleterm/clusters/cluster.go @@ -83,7 +83,7 @@ func (c *Cluster) Connected() bool { // GetWithDetails makes requests to the auth server to return details of the current // Cluster that cannot be found on the disk only, including details about the user // and enabled enterprise features. This method requires a valid cert. -func (c *Cluster) GetWithDetails(ctx context.Context) (*ClusterWithDetails, error) { +func (c *Cluster) GetWithDetails(ctx context.Context, authClient auth.ClientI) (*ClusterWithDetails, error) { var ( authPingResponse proto.PingResponse caps *types.AccessCapabilities @@ -97,20 +97,8 @@ func (c *Cluster) GetWithDetails(ctx context.Context) (*ClusterWithDetails, erro return nil, trace.Wrap(err) } + //TODO(gzdunek): These calls should be done in parallel. err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err := proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - authPingResponse, err = authClient.Ping(ctx) if err != nil { return trace.Wrap(err) @@ -218,12 +206,10 @@ func (c *Cluster) GetRoles(ctx context.Context) ([]*types.Role, error) { } // GetRequestableRoles returns the requestable roles for the currently logged-in user -func (c *Cluster) GetRequestableRoles(ctx context.Context, req *api.GetRequestableRolesRequest) (*types.AccessCapabilities, error) { +func (c *Cluster) GetRequestableRoles(ctx context.Context, req *api.GetRequestableRolesRequest, authClient auth.ClientI) (*types.AccessCapabilities, error) { var ( - authClient auth.ClientI - proxyClient *client.ProxyClient - err error - response *types.AccessCapabilities + err error + response *types.AccessCapabilities ) resourceIds := make([]types.ResourceID, 0, len(req.GetResourceIds())) @@ -237,19 +223,6 @@ func (c *Cluster) GetRequestableRoles(ctx context.Context, req *api.GetRequestab } err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - response, err = authClient.GetAccessCapabilities(ctx, types.AccessCapabilitiesRequest{ ResourceIDs: resourceIds, RequestableRoles: true, diff --git a/lib/teleterm/clusters/cluster_access_requests.go b/lib/teleterm/clusters/cluster_access_requests.go index 42ba8864ae7c7..e34bec4f126a6 100644 --- a/lib/teleterm/clusters/cluster_access_requests.go +++ b/lib/teleterm/clusters/cluster_access_requests.go @@ -46,24 +46,15 @@ type AccessRequest struct { } // GetAccessRequest returns a specific access request by ID and includes resource details -func (c *Cluster) GetAccessRequest(ctx context.Context, req types.AccessRequestFilter) (*AccessRequest, error) { +func (c *Cluster) GetAccessRequest(ctx context.Context, rootAuthClient auth.ClientI, req types.AccessRequestFilter) (*AccessRequest, error) { var ( request types.AccessRequest resourceDetails map[string]ResourceDetails - proxyClient *client.ProxyClient - authClient auth.ClientI err error ) err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - requests, err := proxyClient.GetAccessRequests(ctx, req) + requests, err := rootAuthClient.GetAccessRequests(ctx, req) if err != nil { return trace.Wrap(err) } @@ -75,13 +66,7 @@ func (c *Cluster) GetAccessRequest(ctx context.Context, req types.AccessRequestF } request = requests[0] - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - - resourceDetails, err = getResourceDetails(ctx, request, authClient) + resourceDetails, err = getResourceDetails(ctx, request, rootAuthClient) return err }) @@ -97,13 +82,13 @@ func (c *Cluster) GetAccessRequest(ctx context.Context, req types.AccessRequestF } // Returns all access requests available to the user. -func (c *Cluster) GetAccessRequests(ctx context.Context, req types.AccessRequestFilter) ([]AccessRequest, error) { +func (c *Cluster) GetAccessRequests(ctx context.Context, rootAuthClient auth.ClientI, req types.AccessRequestFilter) ([]AccessRequest, error) { var ( requests []types.AccessRequest err error ) err = AddMetadataToRetryableError(ctx, func() error { - requests, err = c.clusterClient.GetAccessRequests(ctx, req) + requests, err = rootAuthClient.GetAccessRequests(ctx, req) return err }) if err != nil { @@ -122,7 +107,7 @@ func (c *Cluster) GetAccessRequests(ctx context.Context, req types.AccessRequest } // Creates an access request. -func (c *Cluster) CreateAccessRequest(ctx context.Context, req *api.CreateAccessRequestRequest) (*AccessRequest, error) { +func (c *Cluster) CreateAccessRequest(ctx context.Context, rootAuthClient auth.ClientI, req *api.CreateAccessRequestRequest) (*AccessRequest, error) { var ( err error request types.AccessRequest @@ -153,7 +138,7 @@ func (c *Cluster) CreateAccessRequest(ctx context.Context, req *api.CreateAccess var reqOut types.AccessRequest err = AddMetadataToRetryableError(ctx, func() error { - reqOut, err = c.clusterClient.CreateAccessRequestV2(ctx, request) + reqOut, err = rootAuthClient.CreateAccessRequestV2(ctx, request) return trace.Wrap(err) }) if err != nil { @@ -166,11 +151,9 @@ func (c *Cluster) CreateAccessRequest(ctx context.Context, req *api.CreateAccess }, nil } -func (c *Cluster) ReviewAccessRequest(ctx context.Context, req *api.ReviewAccessRequestRequest) (*AccessRequest, error) { +func (c *Cluster) ReviewAccessRequest(ctx context.Context, rootAuthClient auth.ClientI, req *api.ReviewAccessRequestRequest) (*AccessRequest, error) { var ( err error - authClient auth.ClientI - proxyClient *client.ProxyClient updatedRequest types.AccessRequest ) @@ -180,19 +163,6 @@ func (c *Cluster) ReviewAccessRequest(ctx context.Context, req *api.ReviewAccess } err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - reviewSubmission := types.AccessReviewSubmission{ RequestID: req.AccessRequestId, Review: types.AccessReview{ @@ -203,7 +173,7 @@ func (c *Cluster) ReviewAccessRequest(ctx context.Context, req *api.ReviewAccess }, } - updatedRequest, err = authClient.SubmitAccessReview(ctx, reviewSubmission) + updatedRequest, err = rootAuthClient.SubmitAccessReview(ctx, reviewSubmission) return trace.Wrap(err) }) @@ -217,40 +187,15 @@ func (c *Cluster) ReviewAccessRequest(ctx context.Context, req *api.ReviewAccess }, nil } -func (c *Cluster) DeleteAccessRequest(ctx context.Context, req *api.DeleteAccessRequestRequest) error { - var ( - err error - authClient auth.ClientI - proxyClient *client.ProxyClient - ) - - err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - - return authClient.DeleteAccessRequest(ctx, req.AccessRequestId) +func (c *Cluster) DeleteAccessRequest(ctx context.Context, rootAuthClient auth.ClientI, req *api.DeleteAccessRequestRequest) error { + err := AddMetadataToRetryableError(ctx, func() error { + return rootAuthClient.DeleteAccessRequest(ctx, req.AccessRequestId) }) - if err != nil { - return trace.Wrap(err) - } - - return nil + return trace.Wrap(err) } -func (c *Cluster) AssumeRole(ctx context.Context, req *api.AssumeRoleRequest) error { - var err error - - err = AddMetadataToRetryableError(ctx, func() error { +func (c *Cluster) AssumeRole(ctx context.Context, rootProxyClient *client.ProxyClient, req *api.AssumeRoleRequest) error { + err := AddMetadataToRetryableError(ctx, func() error { params := client.ReissueParams{ AccessRequests: req.AccessRequestIds, DropAccessRequests: req.DropRequestIds, @@ -265,26 +210,22 @@ func (c *Cluster) AssumeRole(ctx context.Context, req *api.AssumeRoleRequest) er } // When assuming a role, we want to drop all cached certs otherwise // tsh will continue to use the old certs. - return c.clusterClient.ReissueUserCerts(ctx, client.CertCacheDrop, params) + return rootProxyClient.ReissueUserCerts(ctx, client.CertCacheDrop, params) }) if err != nil { return trace.Wrap(err) } err = c.clusterClient.SaveProfile(true) - if err != nil { - return trace.Wrap(err) - } - - return nil + return trace.Wrap(err) } -func getResourceDetails(ctx context.Context, req types.AccessRequest, clt auth.ClientI) (map[string]ResourceDetails, error) { +func getResourceDetails(ctx context.Context, req types.AccessRequest, rootAuthClient auth.ClientI) (map[string]ResourceDetails, error) { resourceIDsByCluster := accessrequest.GetResourceIDsByCluster(req) resourceDetails := make(map[string]ResourceDetails) for clusterName, resourceIDs := range resourceIDsByCluster { - details, err := accessrequest.GetResourceDetails(ctx, clusterName, clt, resourceIDs) + details, err := accessrequest.GetResourceDetails(ctx, clusterName, rootAuthClient, resourceIDs) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/teleterm/clusters/cluster_apps.go b/lib/teleterm/clusters/cluster_apps.go index 26767333e80db..888811b428c86 100644 --- a/lib/teleterm/clusters/cluster_apps.go +++ b/lib/teleterm/clusters/cluster_apps.go @@ -65,12 +65,10 @@ type AppOrSAMLIdPServiceProvider struct { } // GetApps returns a paginated apps list -func (c *Cluster) GetApps(ctx context.Context, r *api.GetAppsRequest) (*GetAppsResponse, error) { +func (c *Cluster) GetApps(ctx context.Context, authClient auth.ClientI, r *api.GetAppsRequest) (*GetAppsResponse, error) { var ( - page apiclient.ResourcePage[types.AppServerOrSAMLIdPServiceProvider] - authClient auth.ClientI - proxyClient *client.ProxyClient - err error + page apiclient.ResourcePage[types.AppServerOrSAMLIdPServiceProvider] + err error ) req := &proto.ListResourcesRequest{ @@ -85,25 +83,8 @@ func (c *Cluster) GetApps(ctx context.Context, r *api.GetAppsRequest) (*GetAppsR } err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - page, err = apiclient.GetResourcePage[types.AppServerOrSAMLIdPServiceProvider](ctx, authClient, req) - if err != nil { - return trace.Wrap(err) - } - - return nil + return trace.Wrap(err) }) if err != nil { return nil, trace.Wrap(err) @@ -143,10 +124,10 @@ type GetAppsResponse struct { TotalCount int } -func (c *Cluster) getApp(ctx context.Context, appName string) (types.Application, error) { +func (c *Cluster) getApp(ctx context.Context, authClient auth.ClientI, appName string) (types.Application, error) { var app types.Application err := AddMetadataToRetryableError(ctx, func() error { - apps, err := c.clusterClient.ListApps(ctx, &proto.ListResourcesRequest{ + apps, err := apiclient.GetAllResources[types.AppServer](ctx, authClient, &proto.ListResourcesRequest{ Namespace: c.clusterClient.Namespace, ResourceType: types.KindAppServer, PredicateExpression: fmt.Sprintf(`name == "%s"`, appName), @@ -159,7 +140,7 @@ func (c *Cluster) getApp(ctx context.Context, appName string) (types.Application return trace.NotFound("app %q not found", appName) } - app = apps[0] + app = apps[0].GetApp() return nil }) @@ -167,12 +148,12 @@ func (c *Cluster) getApp(ctx context.Context, appName string) (types.Application } // reissueAppCert issue new certificates for the app and saves them to disk. -func (c *Cluster) reissueAppCert(ctx context.Context, app types.Application) (tls.Certificate, error) { +func (c *Cluster) reissueAppCert(ctx context.Context, proxyClient *client.ProxyClient, app types.Application) (tls.Certificate, error) { if app.IsAWSConsole() || app.IsGCP() || app.IsAzureCloud() { return tls.Certificate{}, trace.BadParameter("cloud applications are not supported") } // Refresh the certs to account for clusterClient.SiteName pointing at a leaf cluster. - err := c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ + err := proxyClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ RouteToCluster: c.clusterClient.SiteName, AccessRequests: c.status.ActiveRequests.AccessRequests, }) @@ -180,13 +161,6 @@ func (c *Cluster) reissueAppCert(ctx context.Context, app types.Application) (tl return tls.Certificate{}, trace.Wrap(err) } - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return tls.Certificate{}, trace.Wrap(err) - } - defer proxyClient.Close() - request := types.CreateAppSessionRequest{ Username: c.status.Username, PublicAddr: app.GetPublicAddr(), diff --git a/lib/teleterm/clusters/cluster_databases.go b/lib/teleterm/clusters/cluster_databases.go index 1c467fbadde1f..51dfa9a107773 100644 --- a/lib/teleterm/clusters/cluster_databases.go +++ b/lib/teleterm/clusters/cluster_databases.go @@ -20,6 +20,7 @@ package clusters import ( "context" + "fmt" "github.com/gravitational/trace" @@ -46,67 +47,40 @@ type Database struct { } // GetDatabase returns a database -func (c *Cluster) GetDatabase(ctx context.Context, dbURI uri.ResourceURI) (*Database, error) { - // TODO(ravicious): Fetch a single db instead of filtering the response from GetDatabases. - // https://github.com/gravitational/teleport/pull/14690#discussion_r927720600 - dbs, err := c.getAllDatabases(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - for _, db := range dbs { - if db.URI == dbURI { - return &db, nil - } - } - - return nil, trace.NotFound("database is not found: %v", dbURI) -} - -// GetDatabases returns databases -// TODO(ravicious): Remove this method in favor of fetching a single database in GetDatabase. -// https://github.com/gravitational/teleport/pull/14690#discussion_r927720600 -func (c *Cluster) getAllDatabases(ctx context.Context) ([]Database, error) { - var dbs []types.Database +func (c *Cluster) GetDatabase(ctx context.Context, authClient auth.ClientI, dbURI uri.ResourceURI) (*Database, error) { + var database types.Database + dbName := dbURI.GetDbName() err := AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) + databases, err := apiclient.GetAllResources[types.DatabaseServer](ctx, authClient, &proto.ListResourcesRequest{ + Namespace: c.clusterClient.Namespace, + ResourceType: types.KindDatabaseServer, + PredicateExpression: fmt.Sprintf(`name == "%s"`, dbName), + }) if err != nil { return trace.Wrap(err) } - defer proxyClient.Close() - dbs, err = proxyClient.FindDatabasesByFilters(ctx, proto.ListResourcesRequest{ - Namespace: defaults.Namespace, - ResourceType: types.KindDatabaseServer, - }) - if err != nil { - return trace.Wrap(err) + if len(databases) == 0 { + return trace.NotFound("database %q not found", dbName) } + database = databases[0].GetDatabase() return nil }) if err != nil { return nil, trace.Wrap(err) } - var responseDbs []Database - for _, db := range dbs { - responseDbs = append(responseDbs, Database{ - URI: c.URI.AppendDB(db.GetName()), - Database: db, - }) - } - - return responseDbs, nil + return &Database{ + URI: c.URI.AppendDB(database.GetName()), + Database: database, + }, err } -func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest) (*GetDatabasesResponse, error) { +func (c *Cluster) GetDatabases(ctx context.Context, authClient auth.ClientI, r *api.GetDatabasesRequest) (*GetDatabasesResponse, error) { var ( - page apiclient.ResourcePage[types.DatabaseServer] - authClient auth.ClientI - proxyClient *client.ProxyClient - err error + page apiclient.ResourcePage[types.DatabaseServer] + err error ) req := &proto.ListResourcesRequest{ @@ -121,19 +95,6 @@ func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest) } err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - page, err = apiclient.GetResourcePage[types.DatabaseServer](ctx, authClient, req) return trace.Wrap(err) }) @@ -156,7 +117,7 @@ func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest) } // reissueDBCerts issues new certificates for specific DB access and saves them to disk. -func (c *Cluster) reissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error { +func (c *Cluster) reissueDBCerts(ctx context.Context, proxyClient *client.ProxyClient, routeToDatabase tlsca.RouteToDatabase) error { // When generating certificate for MongoDB access, database username must // be encoded into it. This is required to be able to tell which database // user to authenticate the connection as. @@ -165,7 +126,7 @@ func (c *Cluster) reissueDBCerts(ctx context.Context, routeToDatabase tlsca.Rout } // Refresh the certs to account for clusterClient.SiteName pointing at a leaf cluster. - err := c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ + err := proxyClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ RouteToCluster: c.clusterClient.SiteName, AccessRequests: c.status.ActiveRequests.AccessRequests, }) @@ -174,7 +135,7 @@ func (c *Cluster) reissueDBCerts(ctx context.Context, routeToDatabase tlsca.Rout } // Fetch the certs for the database. - err = c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ + err = proxyClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ RouteToCluster: c.clusterClient.SiteName, RouteToDatabase: proto.RouteToDatabase{ ServiceName: routeToDatabase.ServiceName, @@ -197,41 +158,18 @@ func (c *Cluster) reissueDBCerts(ctx context.Context, routeToDatabase tlsca.Rout } // GetAllowedDatabaseUsers returns allowed users for the given database based on the role set. -func (c *Cluster) GetAllowedDatabaseUsers(ctx context.Context, dbURI string) ([]string, error) { - var authClient auth.ClientI - var proxyClient *client.ProxyClient - +func (c *Cluster) GetAllowedDatabaseUsers(ctx context.Context, authClient auth.ClientI, dbURI string) ([]string, error) { dbResourceURI, err := uri.ParseDBURI(dbURI) if err != nil { return nil, trace.Wrap(err) } - err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - - return nil - }) - if err != nil { - return nil, trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return nil, trace.Wrap(err) - } - defer authClient.Close() - accessChecker, err := services.NewAccessCheckerForRemoteCluster(ctx, c.status.AccessInfo(), c.status.Cluster, authClient) if err != nil { return nil, trace.Wrap(err) } - db, err := c.GetDatabase(ctx, dbResourceURI) + db, err := c.GetDatabase(ctx, authClient, dbResourceURI) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/teleterm/clusters/cluster_gateways.go b/lib/teleterm/clusters/cluster_gateways.go index fa9faaa330ef7..04a3d707fdf39 100644 --- a/lib/teleterm/clusters/cluster_gateways.go +++ b/lib/teleterm/clusters/cluster_gateways.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/mfa" + "github.com/gravitational/teleport/lib/client" libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/gateway" @@ -45,6 +46,7 @@ type CreateGatewayParams struct { OnExpiredCert gateway.OnExpiredCertFunc KubeconfigsDir string MFAPromptConstructor func(cfg *libmfa.PromptConfig) mfa.Prompt + ProxyClient *client.ProxyClient } // CreateGateway creates a gateway @@ -70,7 +72,7 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams) } func (c *Cluster) createDBGateway(ctx context.Context, params CreateGatewayParams) (gateway.Gateway, error) { - db, err := c.GetDatabase(ctx, params.TargetURI) + db, err := c.GetDatabase(ctx, params.ProxyClient.CurrentCluster(), params.TargetURI) if err != nil { return nil, trace.Wrap(err) } @@ -82,7 +84,7 @@ func (c *Cluster) createDBGateway(ctx context.Context, params CreateGatewayParam } err = AddMetadataToRetryableError(ctx, func() error { - return trace.Wrap(c.reissueDBCerts(ctx, routeToDatabase)) + return trace.Wrap(c.reissueDBCerts(ctx, params.ProxyClient, routeToDatabase)) }) if err != nil { return nil, trace.Wrap(err) @@ -117,7 +119,7 @@ func (c *Cluster) createKubeGateway(ctx context.Context, params CreateGatewayPar kube := params.TargetURI.GetKubeName() // Check if this kube exists and the user has access to it. - if _, err := c.getKube(ctx, kube); err != nil { + if _, err := c.getKube(ctx, params.ProxyClient.CurrentCluster(), kube); err != nil { return nil, trace.Wrap(err) } @@ -125,7 +127,7 @@ func (c *Cluster) createKubeGateway(ctx context.Context, params CreateGatewayPar var err error if err := AddMetadataToRetryableError(ctx, func() error { - cert, err = c.reissueKubeCert(ctx, kube) + cert, err = c.reissueKubeCert(ctx, params.ProxyClient, kube) return trace.Wrap(err) }); err != nil { return nil, trace.Wrap(err) @@ -155,7 +157,7 @@ func (c *Cluster) createKubeGateway(ctx context.Context, params CreateGatewayPar func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayParams) (gateway.Gateway, error) { appName := params.TargetURI.GetAppName() - app, err := c.getApp(ctx, appName) + app, err := c.getApp(ctx, params.ProxyClient.CurrentCluster(), appName) if err != nil { return nil, trace.Wrap(err) } @@ -163,7 +165,7 @@ func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayPara var cert tls.Certificate if err := AddMetadataToRetryableError(ctx, func() error { - cert, err = c.reissueAppCert(ctx, app) + cert, err = c.reissueAppCert(ctx, params.ProxyClient, app) return trace.Wrap(err) }); err != nil { return nil, trace.Wrap(err) @@ -194,14 +196,14 @@ func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayPara // At the moment, kube gateways reload their certs in memory while db gateways use the old approach // of saving a cert to disk and only then loading it to memory. // TODO(ravicious): Refactor db gateways to reload cert in memory and support MFA. -func (c *Cluster) ReissueGatewayCerts(ctx context.Context, g gateway.Gateway) (tls.Certificate, error) { +func (c *Cluster) ReissueGatewayCerts(ctx context.Context, proxyClient *client.ProxyClient, g gateway.Gateway) (tls.Certificate, error) { switch { case g.TargetURI().IsDB(): db, err := gateway.AsDatabase(g) if err != nil { return tls.Certificate{}, trace.Wrap(err) } - err = c.reissueDBCerts(ctx, db.RouteToDatabase()) + err = c.reissueDBCerts(ctx, proxyClient, db.RouteToDatabase()) if err != nil { return tls.Certificate{}, trace.Wrap(err) } @@ -218,18 +220,18 @@ func (c *Cluster) ReissueGatewayCerts(ctx context.Context, g gateway.Gateway) (t // from ReissueGatewayCerts, at least not until we add support for MFA to them. return tls.Certificate{}, nil case g.TargetURI().IsKube(): - cert, err := c.reissueKubeCert(ctx, g.TargetName()) + cert, err := c.reissueKubeCert(ctx, proxyClient, g.TargetName()) return cert, trace.Wrap(err) case g.TargetURI().IsApp(): appName := g.TargetURI().GetAppName() - app, err := c.getApp(ctx, appName) + app, err := c.getApp(ctx, proxyClient.CurrentCluster(), appName) if err != nil { return tls.Certificate{}, trace.Wrap(err) } // The cert is saved and then loaded from disk, then returned from this function and finally set // on LocalProxy by the middleware. - cert, err := c.reissueAppCert(ctx, app) + cert, err := c.reissueAppCert(ctx, proxyClient, app) return cert, trace.Wrap(err) default: return tls.Certificate{}, trace.NotImplemented("ReissueGatewayCerts does not support this gateway kind %v", g.TargetURI().String()) diff --git a/lib/teleterm/clusters/cluster_headless.go b/lib/teleterm/clusters/cluster_headless.go index 80c322813e7bb..5bca4fc0e65ba 100644 --- a/lib/teleterm/clusters/cluster_headless.go +++ b/lib/teleterm/clusters/cluster_headless.go @@ -26,52 +26,25 @@ import ( "github.com/gravitational/teleport/api/client/proto" mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" ) // WatchPendingHeadlessAuthentications watches the backend for pending headless authentication requests for the user. -func (c *Cluster) WatchPendingHeadlessAuthentications(ctx context.Context) (watcher types.Watcher, close func(), err error) { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) +func (c *Cluster) WatchPendingHeadlessAuthentications(ctx context.Context, rootAuthClient auth.ClientI) (watcher types.Watcher, close func(), err error) { + watcher, err = rootAuthClient.WatchPendingHeadlessAuthentications(ctx) if err != nil { return nil, nil, trace.Wrap(err) } - rootClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - proxyClient.Close() - return nil, nil, trace.Wrap(err) - } - - watcher, err = rootClient.WatchPendingHeadlessAuthentications(ctx) - if err != nil { - proxyClient.Close() - rootClient.Close() - return nil, nil, trace.Wrap(err) - } - close = func() { watcher.Close() - proxyClient.Close() - rootClient.Close() } return watcher, close, trace.Wrap(err) } // WatchHeadlessAuthentications watches the backend for headless authentication events for the user. -func (c *Cluster) WatchHeadlessAuthentications(ctx context.Context) (watcher types.Watcher, close func(), err error) { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - rootClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - proxyClient.Close() - return nil, nil, trace.Wrap(err) - } - +func (c *Cluster) WatchHeadlessAuthentications(ctx context.Context, rootAuthClient auth.ClientI) (watcher types.Watcher, close func(), err error) { watch := types.Watch{ Kinds: []types.WatchKind{{ Kind: types.KindHeadlessAuthentication, @@ -81,17 +54,13 @@ func (c *Cluster) WatchHeadlessAuthentications(ctx context.Context) (watcher typ }}, } - watcher, err = rootClient.NewWatcher(ctx, watch) + watcher, err = rootAuthClient.NewWatcher(ctx, watch) if err != nil { - proxyClient.Close() - rootClient.Close() return nil, nil, trace.Wrap(err) } close = func() { watcher.Close() - proxyClient.Close() - rootClient.Close() } return watcher, close, trace.Wrap(err) @@ -99,25 +68,12 @@ func (c *Cluster) WatchHeadlessAuthentications(ctx context.Context) (watcher typ // UpdateHeadlessAuthenticationState updates the headless authentication matching the given id to the given state. // MFA will be prompted when updating to the approve state. -func (c *Cluster) UpdateHeadlessAuthenticationState(ctx context.Context, headlessID string, state types.HeadlessAuthenticationState) error { +func (c *Cluster) UpdateHeadlessAuthenticationState(ctx context.Context, rootAuthClient auth.ClientI, headlessID string, state types.HeadlessAuthenticationState) error { err := AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - rootClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - return trace.Wrap(err) - } - defer rootClient.Close() - // If changing state to approved, create an MFA challenge and prompt for MFA. var mfaResponse *proto.MFAAuthenticateResponse if state == types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED { - chall, err := rootClient.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ + chall, err := rootAuthClient.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{ ContextUser: &proto.ContextUser{}, }, @@ -135,7 +91,7 @@ func (c *Cluster) UpdateHeadlessAuthenticationState(ctx context.Context, headles } } - err = rootClient.UpdateHeadlessAuthenticationState(ctx, headlessID, state, mfaResponse) + err := rootAuthClient.UpdateHeadlessAuthenticationState(ctx, headlessID, state, mfaResponse) return trace.Wrap(err) }) return trace.Wrap(err) diff --git a/lib/teleterm/clusters/cluster_kubes.go b/lib/teleterm/clusters/cluster_kubes.go index 6745d2c085077..6f53d308e70c4 100644 --- a/lib/teleterm/clusters/cluster_kubes.go +++ b/lib/teleterm/clusters/cluster_kubes.go @@ -48,12 +48,10 @@ type Kube struct { } // GetKubes returns a paginated kubes list -func (c *Cluster) GetKubes(ctx context.Context, r *api.GetKubesRequest) (*GetKubesResponse, error) { +func (c *Cluster) GetKubes(ctx context.Context, authClient auth.ClientI, r *api.GetKubesRequest) (*GetKubesResponse, error) { var ( - page apiclient.ResourcePage[types.KubeCluster] - authClient auth.ClientI - proxyClient *client.ProxyClient - err error + page apiclient.ResourcePage[types.KubeCluster] + err error ) req := &proto.ListResourcesRequest{ @@ -68,19 +66,6 @@ func (c *Cluster) GetKubes(ctx context.Context, r *api.GetKubesRequest) (*GetKub } err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - page, err = apiclient.GetResourcePage[types.KubeCluster](ctx, authClient, req) if err != nil { return trace.Wrap(err) @@ -116,9 +101,9 @@ type GetKubesResponse struct { } // reissueKubeCert issue new certificates for kube cluster and saves them to disk. -func (c *Cluster) reissueKubeCert(ctx context.Context, kubeCluster string) (tls.Certificate, error) { +func (c *Cluster) reissueKubeCert(ctx context.Context, proxyClient *client.ProxyClient, kubeCluster string) (tls.Certificate, error) { // Refresh the certs to account for clusterClient.SiteName pointing at a leaf cluster. - err := c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ + err := proxyClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{ RouteToCluster: c.clusterClient.SiteName, AccessRequests: c.status.ActiveRequests.AccessRequests, }) @@ -126,13 +111,6 @@ func (c *Cluster) reissueKubeCert(ctx context.Context, kubeCluster string) (tls. return tls.Certificate{}, trace.Wrap(err) } - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return tls.Certificate{}, trace.Wrap(err) - } - defer proxyClient.Close() - key, err := proxyClient.IssueUserCertsWithMFA( ctx, client.ReissueParams{ RouteToCluster: c.clusterClient.SiteName, @@ -150,7 +128,7 @@ func (c *Cluster) reissueKubeCert(ctx context.Context, kubeCluster string) (tls. // via the RBAC rules, but we also need to make sure that the user has // access to the cluster with at least one kubernetes_user or kubernetes_group // defined. - rootClusterName, err := c.clusterClient.RootClusterName(ctx) + rootClusterName, err := proxyClient.RootClusterName(ctx) if err != nil { return tls.Certificate{}, trace.Wrap(err) } @@ -176,30 +154,15 @@ func (c *Cluster) reissueKubeCert(ctx context.Context, kubeCluster string) (tls. return cert, nil } -func (c *Cluster) getKube(ctx context.Context, kubeCluster string) (types.KubeCluster, error) { +func (c *Cluster) getKube(ctx context.Context, authClient auth.ClientI, kubeCluster string) (types.KubeCluster, error) { var kubeClusters []types.KubeCluster err := AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err := proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - + var err error kubeClusters, err = kubeutils.ListKubeClustersWithFilters(ctx, authClient, proto.ListResourcesRequest{ PredicateExpression: fmt.Sprintf("name == %q", kubeCluster), }) - if err != nil { - return trace.Wrap(err) - } - return nil + return trace.Wrap(err) }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/teleterm/clusters/cluster_leaves.go b/lib/teleterm/clusters/cluster_leaves.go index 7924ef956a267..2ba00bd9d1d23 100644 --- a/lib/teleterm/clusters/cluster_leaves.go +++ b/lib/teleterm/clusters/cluster_leaves.go @@ -25,6 +25,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/teleterm/api/uri" ) @@ -41,22 +42,14 @@ type LeafCluster struct { } // GetLeafClusters returns leaf clusters -func (c *Cluster) GetLeafClusters(ctx context.Context) ([]LeafCluster, error) { - var remoteClusters []types.RemoteCluster - err := AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - remoteClusters, err = proxyClient.GetLeafClusters(ctx) - if err != nil { - return trace.Wrap(err) - } - - return nil +func (c *Cluster) GetLeafClusters(ctx context.Context, rootProxyClient *client.ProxyClient) ([]LeafCluster, error) { + var ( + remoteClusters []types.RemoteCluster + err error + ) + err = AddMetadataToRetryableError(ctx, func() error { + remoteClusters, err = rootProxyClient.GetLeafClusters(ctx) + return trace.Wrap(err) }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/teleterm/clusters/cluster_servers.go b/lib/teleterm/clusters/cluster_servers.go index 456ad72bcfe69..834bb25fdf453 100644 --- a/lib/teleterm/clusters/cluster_servers.go +++ b/lib/teleterm/clusters/cluster_servers.go @@ -42,12 +42,10 @@ type Server struct { } // GetServers returns a paginated list of servers. -func (c *Cluster) GetServers(ctx context.Context, r *api.GetServersRequest) (*GetServersResponse, error) { +func (c *Cluster) GetServers(ctx context.Context, r *api.GetServersRequest, authClient auth.ClientI) (*GetServersResponse, error) { var ( - page apiclient.ResourcePage[types.Server] - authClient auth.ClientI - proxyClient *client.ProxyClient - err error + page apiclient.ResourcePage[types.Server] + err error ) req := &proto.ListResourcesRequest{ @@ -62,19 +60,6 @@ func (c *Cluster) GetServers(ctx context.Context, r *api.GetServersRequest) (*Ge } err = AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err = c.clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err = proxyClient.ConnectToCluster(ctx, c.clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - page, err = apiclient.GetResourcePage[types.Server](ctx, authClient, req) if err != nil { return trace.Wrap(err) diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index e629eb0273082..7d195694c51ae 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/clusters" + "github.com/gravitational/teleport/lib/teleterm/services/clientcache" "github.com/gravitational/teleport/lib/teleterm/services/connectmycomputer" ) @@ -69,6 +70,21 @@ type Config struct { ConnectMyComputerNodeJoinWait *connectmycomputer.NodeJoinWait ConnectMyComputerNodeDelete *connectmycomputer.NodeDelete ConnectMyComputerNodeName *connectmycomputer.NodeName + + ClientCache ClientCache +} + +// ClientCache stores clients keyed by cluster URI. +type ClientCache interface { + // Get returns a client from the cache if there is one, + // otherwise it dials the remote server. + // The caller should not close the returned client. + Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ProxyClient, error) + // ClearForRoot closes and removes clients from the cache + // for the root cluster and its leaf clusters. + ClearForRoot(clusterURI uri.ResourceURI) error + // Clear closes and removes all clients. + Clear() error } type CreateTshdEventsClientCredsFunc func() (grpc.DialOption, error) @@ -140,5 +156,12 @@ func (c *Config) CheckAndSetDefaults() error { c.ConnectMyComputerNodeName = nodeName } + if c.ClientCache == nil { + c.ClientCache = clientcache.New(clientcache.Config{ + Log: c.Log, + Resolver: c.Storage, + }) + } + return nil } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 783836688da1b..b1cc5d5446a06 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -161,12 +161,17 @@ func (s *Service) ListLeafClusters(ctx context.Context, uri string) ([]clusters. return nil, trace.Wrap(err) } + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + // leaf cluster cannot have own leaves if cluster.URI.GetLeafClusterName() != "" { return nil, nil } - leaves, err := cluster.GetLeafClusters(ctx) + leaves, err := cluster.GetLeafClusters(ctx, proxyClient) if err != nil { return nil, trace.Wrap(err) } @@ -243,7 +248,12 @@ func (s *Service) ResolveClusterWithDetails(ctx context.Context, uri string) (*c return nil, nil, trace.Wrap(err) } - withDetails, err := cluster.GetWithDetails(ctx) + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + withDetails, err := cluster.GetWithDetails(ctx, proxyClient.CurrentCluster()) if err != nil { return nil, nil, trace.Wrap(err) } @@ -266,7 +276,7 @@ func (s *Service) ClusterLogout(ctx context.Context, uri string) error { return trace.Wrap(err) } - return nil + return trace.Wrap(s.ClearCachedClientsForRoot(cluster.URI)) } // CreateGateway creates a gateway to given targetURI @@ -297,6 +307,11 @@ func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) return gateway, nil } + proxyClient, err := s.GetCachedClient(ctx, targetURI.GetClusterURI()) + if err != nil { + return nil, trace.Wrap(err) + } + clusterCreateGatewayParams := clusters.CreateGatewayParams{ TargetURI: targetURI, TargetUser: params.TargetUser, @@ -305,6 +320,7 @@ func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) OnExpiredCert: s.reissueGatewayCerts, KubeconfigsDir: s.cfg.KubeconfigsDir, MFAPromptConstructor: s.NewMFAPromptConstructor(targetURI.String()), + ProxyClient: proxyClient, } gateway, err := s.cfg.GatewayCreator.CreateGateway(ctx, clusterCreateGatewayParams) @@ -344,7 +360,12 @@ func (s *Service) reissueGatewayCerts(ctx context.Context, g gateway.Gateway) (t return trace.Wrap(err) } - cert, err = cluster.ReissueGatewayCerts(ctx, g) + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return trace.Wrap(err) + } + + cert, err = cluster.ReissueGatewayCerts(ctx, proxyClient, g) if err != nil { return trace.Wrap(err) } @@ -531,7 +552,12 @@ func (s *Service) GetServers(ctx context.Context, req *api.GetServersRequest) (* return nil, trace.Wrap(err) } - response, err := cluster.GetServers(ctx, req) + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + response, err := cluster.GetServers(ctx, req, proxyClient.CurrentCluster()) if err != nil { return nil, trace.Wrap(err) } @@ -545,7 +571,12 @@ func (s *Service) GetRequestableRoles(ctx context.Context, req *api.GetRequestab return nil, trace.Wrap(err) } - response, err := cluster.GetRequestableRoles(ctx, req) + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + response, err := cluster.GetRequestableRoles(ctx, req, proxyClient.CurrentCluster()) if err != nil { return nil, trace.Wrap(err) } @@ -558,27 +589,19 @@ func (s *Service) GetRequestableRoles(ctx context.Context, req *api.GetRequestab // PromoteAccessRequest promotes an access request to an access list. func (s *Service) PromoteAccessRequest(ctx context.Context, rootClusterURI uri.ResourceURI, req *accesslistv1.AccessRequestPromoteRequest) (*clusters.AccessRequest, error) { - cluster, clusterClient, err := s.ResolveClusterURI(rootClusterURI) + cluster, _, err := s.ResolveClusterURI(rootClusterURI) + if err != nil { + return nil, trace.Wrap(err) + } + + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) if err != nil { return nil, trace.Wrap(err) } var response *clusters.AccessRequest err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err := proxyClient.ConnectToCluster(ctx, clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - - promoteResponse, err := authClient.AccessListClient().AccessRequestPromote(ctx, req) + promoteResponse, err := proxyClient.CurrentCluster().AccessListClient().AccessRequestPromote(ctx, req) if err != nil { return trace.Wrap(err) } @@ -595,25 +618,14 @@ func (s *Service) PromoteAccessRequest(ctx context.Context, rootClusterURI uri.R // GetSuggestedAccessLists returns suggested access lists for an access request. func (s *Service) GetSuggestedAccessLists(ctx context.Context, rootClusterURI uri.ResourceURI, accessRequestID string) ([]*accesslist.AccessList, error) { - _, clusterClient, err := s.ResolveClusterURI(rootClusterURI) + proxyClient, err := s.GetCachedClient(ctx, rootClusterURI) if err != nil { return nil, trace.Wrap(err) } var response []*accesslist.AccessList err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err := proxyClient.ConnectToCluster(ctx, clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() + authClient := proxyClient.CurrentCluster() accessLists, err := authClient.AccessListClient().GetSuggestedAccessLists(ctx, accessRequestID) if err != nil { @@ -632,7 +644,13 @@ func (s *Service) GetAccessRequests(ctx context.Context, req *api.GetAccessReque if err != nil { return nil, trace.Wrap(err) } - response, err := cluster.GetAccessRequests(ctx, types.AccessRequestFilter{}) + + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + response, err := cluster.GetAccessRequests(ctx, proxyClient.CurrentCluster(), types.AccessRequestFilter{}) if err != nil { return nil, trace.Wrap(err) } @@ -651,7 +669,12 @@ func (s *Service) GetAccessRequest(ctx context.Context, req *api.GetAccessReques return nil, trace.Wrap(err) } - response, err := cluster.GetAccessRequest(ctx, types.AccessRequestFilter{ + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + response, err := cluster.GetAccessRequest(ctx, proxyClient.CurrentCluster(), types.AccessRequestFilter{ ID: req.AccessRequestId, }) if err != nil { @@ -667,7 +690,13 @@ func (s *Service) CreateAccessRequest(ctx context.Context, req *api.CreateAccess if err != nil { return nil, trace.Wrap(err) } - request, err := cluster.CreateAccessRequest(ctx, req) + + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + request, err := cluster.CreateAccessRequest(ctx, proxyClient.CurrentCluster(), req) if err != nil { return nil, trace.Wrap(err) } @@ -680,7 +709,13 @@ func (s *Service) ReviewAccessRequest(ctx context.Context, req *api.ReviewAccess if err != nil { return nil, trace.Wrap(err) } - response, err := cluster.ReviewAccessRequest(ctx, req) + + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + response, err := cluster.ReviewAccessRequest(ctx, proxyClient.CurrentCluster(), req) if err != nil { return nil, trace.Wrap(err) } @@ -698,12 +733,12 @@ func (s *Service) DeleteAccessRequest(ctx context.Context, req *api.DeleteAccess return trace.Wrap(err) } - err = cluster.DeleteAccessRequest(ctx, req) + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) if err != nil { return trace.Wrap(err) } - return nil + return trace.Wrap(cluster.DeleteAccessRequest(ctx, proxyClient.CurrentCluster(), req)) } func (s *Service) AssumeRole(ctx context.Context, req *api.AssumeRoleRequest) error { @@ -712,12 +747,17 @@ func (s *Service) AssumeRole(ctx context.Context, req *api.AssumeRoleRequest) er return trace.Wrap(err) } - err = cluster.AssumeRole(ctx, req) + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) if err != nil { return trace.Wrap(err) } - return nil + if err := cluster.AssumeRole(ctx, proxyClient, req); err != nil { + return trace.Wrap(err) + } + + // We have to reconnect using the updated cert. + return trace.Wrap(s.ClearCachedClientsForRoot(cluster.URI)) } // GetKubes accepts parameterized input to enable searching, sorting, and pagination. @@ -727,7 +767,12 @@ func (s *Service) GetKubes(ctx context.Context, req *api.GetKubesRequest) (*clus return nil, trace.Wrap(err) } - response, err := cluster.GetKubes(ctx, req) + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } + + response, err := cluster.GetKubes(ctx, proxyClient.CurrentCluster(), req) if err != nil { return nil, trace.Wrap(err) } @@ -757,6 +802,10 @@ func (s *Service) Stop() { s.StopHeadlessWatchers() + if err := s.cfg.ClientCache.Clear(); err != nil { + s.cfg.Log.WithError(err).Error("Failed to close remote clients") + } + timeoutCtx, cancel := context.WithTimeout(s.closeContext, time.Second*10) defer cancel() @@ -825,26 +874,19 @@ func (s *Service) TransferFile(ctx context.Context, request *api.FileTransferReq // teleport.dev/connect-my-computer/owner: and allows logging in to those nodes as // the current system user. func (s *Service) CreateConnectMyComputerRole(ctx context.Context, req *api.CreateConnectMyComputerRoleRequest) (*api.CreateConnectMyComputerRoleResponse, error) { - cluster, clusterClient, err := s.ResolveCluster(req.RootClusterUri) + cluster, _, err := s.ResolveCluster(req.RootClusterUri) if err != nil { return nil, trace.Wrap(err) } - response := &api.CreateConnectMyComputerRoleResponse{} - err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - authClient, err := proxyClient.ConnectToCluster(ctx, clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } - result, err := s.cfg.ConnectMyComputerRoleSetup.Run(ctx, authClient, proxyClient, cluster) + response := &api.CreateConnectMyComputerRoleResponse{} + err = clusters.AddMetadataToRetryableError(ctx, func() error { + result, err := s.cfg.ConnectMyComputerRoleSetup.Run(ctx, proxyClient.CurrentCluster(), proxyClient, cluster) if err != nil { return trace.Wrap(err) } @@ -857,26 +899,19 @@ func (s *Service) CreateConnectMyComputerRole(ctx context.Context, req *api.Crea // CreateConnectMyComputerNodeToken creates a node join token that is valid for 5 minutes. func (s *Service) CreateConnectMyComputerNodeToken(ctx context.Context, rootClusterUri string) (string, error) { - cluster, clusterClient, err := s.ResolveCluster(rootClusterUri) + cluster, _, err := s.ResolveCluster(rootClusterUri) if err != nil { return "", trace.Wrap(err) } - var nodeToken string - err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - authClient, err := proxyClient.ConnectToCluster(ctx, clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return "", trace.Wrap(err) + } - nodeToken, err = s.cfg.ConnectMyComputerTokenProvisioner.CreateNodeToken(ctx, authClient, cluster) + var nodeToken string + err = clusters.AddMetadataToRetryableError(ctx, func() error { + nodeToken, err = s.cfg.ConnectMyComputerTokenProvisioner.CreateNodeToken(ctx, proxyClient.CurrentCluster(), cluster) return trace.Wrap(err) }) @@ -885,25 +920,18 @@ func (s *Service) CreateConnectMyComputerNodeToken(ctx context.Context, rootClus // DeleteConnectMyComputerNode deletes the Connect My Computer node. func (s *Service) DeleteConnectMyComputerNode(ctx context.Context, req *api.DeleteConnectMyComputerNodeRequest) (*api.DeleteConnectMyComputerNodeResponse, error) { - cluster, clusterClient, err := s.ResolveCluster(req.GetRootClusterUri()) + cluster, _, err := s.ResolveCluster(req.GetRootClusterUri()) if err != nil { return &api.DeleteConnectMyComputerNodeResponse{}, trace.Wrap(err) } - err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - authClient, err := proxyClient.ConnectToCluster(ctx, clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return nil, trace.Wrap(err) + } - err = s.cfg.ConnectMyComputerNodeDelete.Run(ctx, authClient, cluster) + err = clusters.AddMetadataToRetryableError(ctx, func() error { + err = s.cfg.ConnectMyComputerNodeDelete.Run(ctx, proxyClient.CurrentCluster(), cluster) return trace.Wrap(err) }) @@ -924,27 +952,19 @@ func (s *Service) GetConnectMyComputerNodeName(req *api.GetConnectMyComputerNode // WaitForConnectMyComputerNodeJoin returns a response only after detecting that a Connect My // Computer node for the given cluster has joined the cluster. func (s *Service) WaitForConnectMyComputerNodeJoin(ctx context.Context, rootClusterURI uri.ResourceURI) (clusters.Server, error) { - cluster, clusterClient, err := s.ResolveClusterURI(rootClusterURI) + cluster, _, err := s.ResolveClusterURI(rootClusterURI) + if err != nil { + return clusters.Server{}, trace.Wrap(err) + } + + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) if err != nil { return clusters.Server{}, trace.Wrap(err) } var server clusters.Server err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err := proxyClient.ConnectToCluster(ctx, clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - - server, err = s.cfg.ConnectMyComputerNodeJoinWait.Run(ctx, authClient, cluster) + server, err = s.cfg.ConnectMyComputerNodeJoinWait.Run(ctx, proxyClient.CurrentCluster(), cluster) if err != nil { return trace.Wrap(err) } @@ -957,7 +977,12 @@ func (s *Service) WaitForConnectMyComputerNodeJoin(ctx context.Context, rootClus // ListUnifiedResources returns resources for the given cluster and search params. func (s *Service) ListUnifiedResources(ctx context.Context, clusterURI uri.ResourceURI, req *proto.ListUnifiedResourcesRequest) (*unifiedresources.ListResponse, error) { - cluster, clusterClient, err := s.ResolveClusterURI(clusterURI) + cluster, _, err := s.ResolveClusterURI(clusterURI) + if err != nil { + return nil, trace.Wrap(err) + } + + proxyClient, err := s.GetCachedClient(ctx, clusterURI) if err != nil { return nil, trace.Wrap(err) } @@ -965,20 +990,7 @@ func (s *Service) ListUnifiedResources(ctx context.Context, clusterURI uri.Resou var resources *unifiedresources.ListResponse err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := clusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - authClient, err := proxyClient.ConnectToCluster(ctx, clusterClient.SiteName) - if err != nil { - return trace.Wrap(err) - } - defer authClient.Close() - - resources, err = unifiedresources.List(ctx, cluster, authClient, req) + resources, err = unifiedresources.List(ctx, cluster, proxyClient.CurrentCluster(), req) if err != nil { return trace.Wrap(err) } @@ -991,35 +1003,24 @@ func (s *Service) ListUnifiedResources(ctx context.Context, clusterURI uri.Resou // GetUserPreferences returns the preferences for a given user. func (s *Service) GetUserPreferences(ctx context.Context, clusterURI uri.ResourceURI) (*api.UserPreferences, error) { - _, rootClusterClient, err := s.ResolveClusterURI(clusterURI.GetRootClusterURI()) + rootProxyClient, err := s.GetCachedClient(ctx, clusterURI.GetRootClusterURI()) if err != nil { return nil, trace.Wrap(err) } - leafClusterName := clusterURI.GetLeafClusterName() var preferences *api.UserPreferences err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := rootClusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - rootAuthClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - return trace.Wrap(err) - } - defer rootAuthClient.Close() + rootAuthClient := rootProxyClient.CurrentCluster() var leafAuthClient auth.ClientI - if leafClusterName != "" { - leafAuthClient, err = proxyClient.ConnectToCluster(ctx, leafClusterName) + if clusterURI.IsLeaf() { + leafProxyClient, err := s.GetCachedClient(ctx, clusterURI.GetClusterURI()) if err != nil { return trace.Wrap(err) } - defer leafAuthClient.Close() + + leafAuthClient = leafProxyClient.CurrentCluster() } preferences, err = userpreferences.Get(ctx, rootAuthClient, leafAuthClient) @@ -1035,35 +1036,24 @@ func (s *Service) GetUserPreferences(ctx context.Context, clusterURI uri.Resourc // UpdateUserPreferences updates the preferences for a given user. func (s *Service) UpdateUserPreferences(ctx context.Context, clusterURI uri.ResourceURI, newPreferences *api.UserPreferences) (*api.UserPreferences, error) { - _, rootClusterClient, err := s.ResolveClusterURI(clusterURI.GetRootClusterURI()) + rootProxyClient, err := s.GetCachedClient(ctx, clusterURI.GetRootClusterURI()) if err != nil { return nil, trace.Wrap(err) } - leafClusterName := clusterURI.GetLeafClusterName() var preferences *api.UserPreferences err = clusters.AddMetadataToRetryableError(ctx, func() error { - //nolint:staticcheck // SA1019. TODO(tross) update to use ClusterClient - proxyClient, err := rootClusterClient.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - rootAuthClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - return trace.Wrap(err) - } - defer rootAuthClient.Close() + rootAuthClient := rootProxyClient.CurrentCluster() var leafAuthClient auth.ClientI - if leafClusterName != "" { - leafAuthClient, err = proxyClient.ConnectToCluster(ctx, leafClusterName) + if clusterURI.IsLeaf() { + leafProxyClient, err := s.GetCachedClient(ctx, clusterURI.GetClusterURI()) if err != nil { return trace.Wrap(err) } - defer leafAuthClient.Close() + + leafAuthClient = leafProxyClient.CurrentCluster() } preferences, err = userpreferences.Update(ctx, rootAuthClient, leafAuthClient, newPreferences) @@ -1091,6 +1081,19 @@ func (s *Service) findGatewayByTargetURI(targetURI uri.ResourceURI) (gateway.Gat return nil, false } +// GetCachedClient returns a client from the cache if it exists, +// otherwise it dials the remote server. +func (s *Service) GetCachedClient(ctx context.Context, clusterURI uri.ResourceURI) (*client.ProxyClient, error) { + clt, err := s.cfg.ClientCache.Get(ctx, clusterURI) + return clt, trace.Wrap(err) +} + +// ClearCachedClientsForRoot closes and removes clients from the cache +// for the root cluster and its leaf clusters. +func (s *Service) ClearCachedClientsForRoot(clusterURI uri.ResourceURI) error { + return trace.Wrap(s.cfg.ClientCache.ClearForRoot(clusterURI)) +} + // Service is the daemon service type Service struct { cfg *Config diff --git a/lib/teleterm/daemon/daemon_headless.go b/lib/teleterm/daemon/daemon_headless.go index 3f76efb0c4265..315db4a4af3e5 100644 --- a/lib/teleterm/daemon/daemon_headless.go +++ b/lib/teleterm/daemon/daemon_headless.go @@ -34,29 +34,34 @@ import ( ) // UpdateHeadlessAuthenticationState updates a headless authentication state. -func (s *Service) UpdateHeadlessAuthenticationState(ctx context.Context, clusterURI, headlessID string, state api.HeadlessAuthenticationState) error { - cluster, _, err := s.ResolveCluster(clusterURI) +func (s *Service) UpdateHeadlessAuthenticationState(ctx context.Context, rootClusterURI, headlessID string, state api.HeadlessAuthenticationState) error { + cluster, _, err := s.ResolveCluster(rootClusterURI) if err != nil { return trace.Wrap(err) } - if err := cluster.UpdateHeadlessAuthenticationState(ctx, headlessID, types.HeadlessAuthenticationState(state)); err != nil { + proxyClient, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return trace.Wrap(err) + } + + if err := cluster.UpdateHeadlessAuthenticationState(ctx, proxyClient.CurrentCluster(), headlessID, types.HeadlessAuthenticationState(state)); err != nil { return trace.Wrap(err) } return nil } -// StartHeadlessHandlers starts a headless watcher for the given cluster URI. +// StartHeadlessWatcher starts a headless watcher for the given cluster URI. // // If waitInit is true, this method will wait for the watcher to connect to the // Auth Server and receive an OpInit event to indicate that the watcher is fully // initialized and ready to catch headless events. -func (s *Service) StartHeadlessWatcher(uri string, waitInit bool) error { +func (s *Service) StartHeadlessWatcher(rootClusterURI string, waitInit bool) error { s.headlessWatcherClosersMu.Lock() defer s.headlessWatcherClosersMu.Unlock() - cluster, _, err := s.ResolveCluster(uri) + cluster, _, err := s.ResolveCluster(rootClusterURI) if err != nil { return trace.Wrap(err) } @@ -92,10 +97,10 @@ func (s *Service) StartHeadlessWatchers() error { // If waitInit is true, this method will wait for the watcher to connect to the // Auth Server and receive an OpInit event to indicate that the watcher is fully // initialized and ready to catch headless events. -func (s *Service) startHeadlessWatcher(cluster *clusters.Cluster, waitInit bool) error { +func (s *Service) startHeadlessWatcher(rootCluster *clusters.Cluster, waitInit bool) error { // If there is already a watcher for this cluster, close and replace it. // This may occur after relogin, for example. - if err := s.stopHeadlessWatcher(cluster.URI.String()); err != nil && !trace.IsNotFound(err) { + if err := s.stopHeadlessWatcher(rootCluster.URI.String()); err != nil && !trace.IsNotFound(err) { return trace.Wrap(err) } @@ -112,9 +117,9 @@ func (s *Service) startHeadlessWatcher(cluster *clusters.Cluster, waitInit bool) } watchCtx, watchCancel := context.WithCancel(s.closeContext) - s.headlessWatcherClosers[cluster.URI.String()] = watchCancel + s.headlessWatcherClosers[rootCluster.URI.String()] = watchCancel - log := s.cfg.Log.WithField("cluster", cluster.URI.String()) + log := s.cfg.Log.WithField("cluster", rootCluster.URI.String()) pendingRequests := make(map[string]context.CancelFunc) pendingRequestsMu := sync.Mutex{} @@ -137,13 +142,19 @@ func (s *Service) startHeadlessWatcher(cluster *clusters.Cluster, waitInit bool) pendingWatcherInitializedOnce := sync.Once{} watch := func() error { - pendingWatcher, closePendingWatcher, err := cluster.WatchPendingHeadlessAuthentications(watchCtx) + proxyClient, err := s.GetCachedClient(watchCtx, rootCluster.URI) + if err != nil { + return trace.Wrap(err) + } + authClient := proxyClient.CurrentCluster() + + pendingWatcher, closePendingWatcher, err := rootCluster.WatchPendingHeadlessAuthentications(watchCtx, authClient) if err != nil { return trace.Wrap(err) } defer closePendingWatcher() - resolutionWatcher, closeResolutionWatcher, err := cluster.WatchHeadlessAuthentications(watchCtx) + resolutionWatcher, closeResolutionWatcher, err := rootCluster.WatchHeadlessAuthentications(watchCtx, authClient) if err != nil { return trace.Wrap(err) } @@ -190,7 +201,7 @@ func (s *Service) startHeadlessWatcher(cluster *clusters.Cluster, waitInit bool) // We do this in a goroutine so the watch loop can continue and cancel resolved requests. go func() { defer cancelSend() - if err := s.sendPendingHeadlessAuthentication(sendCtx, ha, cluster.URI.String()); err != nil { + if err := s.sendPendingHeadlessAuthentication(sendCtx, ha, rootCluster.URI.String()); err != nil { if !strings.Contains(err.Error(), context.Canceled.Error()) && !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) { log.WithError(err).Debug("sendPendingHeadlessAuthentication resulted in unexpected error.") } @@ -233,14 +244,14 @@ func (s *Service) startHeadlessWatcher(cluster *clusters.Cluster, waitInit bool) // watcher was canceled by an outside call to stopHeadlessWatcher. default: // watcher closed due to error or cluster disconnect. - if err := s.stopHeadlessWatcher(cluster.URI.String()); err != nil { + if err := s.stopHeadlessWatcher(rootCluster.URI.String()); err != nil { log.WithError(err).Debug("Failed to remove headless watcher.") } } }() for { - if !cluster.Connected() { + if !rootCluster.Connected() { log.Debugf("Not connected to cluster. Returning from headless watch loop.") return } @@ -276,9 +287,9 @@ func (s *Service) startHeadlessWatcher(cluster *clusters.Cluster, waitInit bool) } // sendPendingHeadlessAuthentication notifies the Electron App of a pending headless authentication. -func (s *Service) sendPendingHeadlessAuthentication(ctx context.Context, ha *types.HeadlessAuthentication, clusterURI string) error { +func (s *Service) sendPendingHeadlessAuthentication(ctx context.Context, ha *types.HeadlessAuthentication, rootClusterURI string) error { req := &api.SendPendingHeadlessAuthenticationRequest{ - RootClusterUri: clusterURI, + RootClusterUri: rootClusterURI, HeadlessAuthenticationId: ha.GetName(), HeadlessAuthenticationClientIp: ha.ClientIpAddress, } diff --git a/lib/teleterm/daemon/daemon_test.go b/lib/teleterm/daemon/daemon_test.go index 44f79f6ba0336..ebda3c15a3706 100644 --- a/lib/teleterm/daemon/daemon_test.go +++ b/lib/teleterm/daemon/daemon_test.go @@ -272,6 +272,7 @@ func TestGatewayCRUD(t *testing.T) { GatewayCreator: mockGatewayCreator, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + ClientCache: fakeClientCache{}, }) require.NoError(t, err) @@ -450,6 +451,7 @@ func TestRetryWithRelogin(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + ClientCache: fakeClientCache{}, }) require.NoError(t, err) @@ -500,6 +502,7 @@ func TestImportantModalSemaphore(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + ClientCache: fakeClientCache{}, }) require.NoError(t, err) @@ -648,6 +651,7 @@ func TestGetGatewayCLICommand(t *testing.T) { }, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + ClientCache: fakeClientCache{}, }) require.NoError(t, err) @@ -724,3 +728,11 @@ type fakeStorage struct { func (f fakeStorage) GetByResourceURI(resourceURI uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) { return &clusters.Cluster{}, &client.TeleportClient{}, nil } + +type fakeClientCache struct { + ClientCache +} + +func (f fakeClientCache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ProxyClient, error) { + return &client.ProxyClient{}, nil +} diff --git a/lib/teleterm/services/clientcache/clientcache.go b/lib/teleterm/services/clientcache/clientcache.go new file mode 100644 index 0000000000000..3345a6dd105e3 --- /dev/null +++ b/lib/teleterm/services/clientcache/clientcache.go @@ -0,0 +1,189 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package clientcache + +import ( + "context" + "sync" + + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" + + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusters" +) + +// Cache stores clients keyed by cluster URI. +// Safe for concurrent access. +// Closes all clients and wipes the cache on Clear. +type Cache struct { + cfg Config + mu sync.Mutex + // clients keep mapping between cluster URI + // (both root and leaf) and proxy clients + clients map[uri.ResourceURI]*client.ProxyClient + // group prevents duplicate requests to create clients + // for a given cluster URI + group singleflight.Group +} + +// Config describes the client cache configuration. +type Config struct { + Resolver clusters.Resolver + Log logrus.FieldLogger +} + +func (c *Config) checkAndSetDefaults() { + if c.Log == nil { + c.Log = logrus.WithField(trace.Component, "clientcache") + } +} + +// New creates an instance of Cache. +func New(c Config) *Cache { + c.checkAndSetDefaults() + + return &Cache{ + cfg: c, + clients: make(map[uri.ResourceURI]*client.ProxyClient), + } +} + +// Get returns a client from the cache if there is one, +// otherwise it dials the remote server. +// The caller should not close the returned client. +func (c *Cache) Get(ctx context.Context, clusterURI uri.ResourceURI) (*client.ProxyClient, error) { + groupClt, err, _ := c.group.Do(clusterURI.String(), func() (any, error) { + if fromCache := c.getFromCache(clusterURI); fromCache != nil { + return fromCache, nil + } + + _, clusterClient, err := c.cfg.Resolver.ResolveCluster(clusterURI) + if err != nil { + return nil, trace.Wrap(err) + } + + //nolint:staticcheck // SA1019. TODO(gzdunek): Update to use client.ClusterClient. + newProxyClient, err := clusterClient.ConnectToProxy(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + // We'll save the client in the cache, so we don't have to + // build a new connection next time. + // All cached clients will be closed when the daemon exits. + if err := c.addToCache(clusterURI, newProxyClient); err != nil { + return nil, trace.NewAggregate(err, newProxyClient.Close()) + } + + c.cfg.Log.WithField("cluster", clusterURI).Info("Added client to cache.") + + return newProxyClient, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + + clt, ok := groupClt.(*client.ProxyClient) + if !ok { + return nil, trace.BadParameter("unexpected type %T received for proxy client", groupClt) + } + + return clt, nil +} + +// ClearForRoot closes and removes clients from the cache +// for the root cluster and its leaf clusters. +func (c *Cache) ClearForRoot(clusterURI uri.ResourceURI) error { + c.mu.Lock() + defer c.mu.Unlock() + + rootClusterURI := clusterURI.GetRootClusterURI() + var ( + errors []error + deleted []uri.ResourceURI + ) + + for resourceURI, clt := range c.clients { + if resourceURI.GetRootClusterURI() == rootClusterURI { + if err := clt.Close(); err != nil { + errors = append(errors, err) + } + deleted = append(deleted, resourceURI.GetClusterURI()) + delete(c.clients, resourceURI) + } + } + + c.cfg.Log.WithFields(logrus.Fields{"cluster": rootClusterURI, "clients": deleted}).Info("Invalidated cached clients for root cluster.") + + return trace.NewAggregate(errors...) + +} + +// Clear closes and removes all clients. +func (c *Cache) Clear() error { + c.mu.Lock() + defer c.mu.Unlock() + + var errors []error + for _, clt := range c.clients { + if err := clt.Close(); err != nil { + errors = append(errors, err) + } + } + clear(c.clients) + + return trace.NewAggregate(errors...) +} + +func (c *Cache) addToCache(clusterURI uri.ResourceURI, proxyClient *client.ProxyClient) error { + c.mu.Lock() + defer c.mu.Unlock() + + var err error + if c.clients[clusterURI] != nil { + err = c.clients[clusterURI].Close() + } + c.clients[clusterURI] = proxyClient + + // This goroutine removes the connection from the cache when + // it is unexpectedly interrupted (for example, by the remote site). + // It will also react to client.Close() called from our side, but it will be noop. + go func() { + err := proxyClient.Client.Wait() + c.mu.Lock() + defer c.mu.Unlock() + + if c.clients[clusterURI] != proxyClient { + return + } + + delete(c.clients, clusterURI) + c.cfg.Log.WithField("cluster", clusterURI).WithError(err).Info("Connection has been closed, removed client from cache.") + }() + return trace.Wrap(err) +} + +func (c *Cache) getFromCache(clusterURI uri.ResourceURI) *client.ProxyClient { + c.mu.Lock() + defer c.mu.Unlock() + + clt := c.clients[clusterURI] + return clt +} diff --git a/web/packages/teleterm/src/ui/utils/retryWithRelogin.ts b/web/packages/teleterm/src/ui/utils/retryWithRelogin.ts index 5c05b7b0d733a..8237fc8e2febc 100644 --- a/web/packages/teleterm/src/ui/utils/retryWithRelogin.ts +++ b/web/packages/teleterm/src/ui/utils/retryWithRelogin.ts @@ -97,7 +97,9 @@ export function isRetryable(error: unknown): boolean { return ( error instanceof Error && (error.message.includes('ssh: handshake failed') || - error.message.includes('ssh: cert has expired')) + error.message.includes('ssh: cert has expired') || + error.message.includes('tls: expired certificate') || + error.message.includes('client credentials have expired')) ); }