diff --git a/lib/cache/cache.go b/lib/cache/cache.go index eb2ea24c33310..c1f6e4e87cc41 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -2409,86 +2409,6 @@ func (c *Cache) GetProxies() ([]types.Server, error) { return rg.reader.GetProxies() } -type remoteClustersCacheKey struct { - name string -} - -// GetRemoteClusters returns a list of remote clusters -func (c *Cache) GetRemoteClusters(ctx context.Context) ([]types.RemoteCluster, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetRemoteClusters") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.remoteClusters) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - if !rg.IsCacheRead() { - cachedRemotes, err := utils.FnCacheGet(ctx, c.fnCache, remoteClustersCacheKey{}, func(ctx context.Context) ([]types.RemoteCluster, error) { - remotes, err := rg.reader.GetRemoteClusters(ctx) - return remotes, err - }) - if err != nil || cachedRemotes == nil { - return nil, trace.Wrap(err) - } - - remotes := make([]types.RemoteCluster, 0, len(cachedRemotes)) - for _, remote := range cachedRemotes { - remotes = append(remotes, remote.Clone()) - } - return remotes, nil - } - return rg.reader.GetRemoteClusters(ctx) -} - -// GetRemoteCluster returns a remote cluster by name -func (c *Cache) GetRemoteCluster(ctx context.Context, clusterName string) (types.RemoteCluster, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetRemoteCluster") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.remoteClusters) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - if !rg.IsCacheRead() { - cachedRemote, err := utils.FnCacheGet(ctx, c.fnCache, remoteClustersCacheKey{clusterName}, func(ctx context.Context) (types.RemoteCluster, error) { - remote, err := rg.reader.GetRemoteCluster(ctx, clusterName) - return remote, err - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return cachedRemote.Clone(), nil - } - rc, err := rg.reader.GetRemoteCluster(ctx, clusterName) - if trace.IsNotFound(err) && rg.IsCacheRead() { - // release read lock early - rg.Release() - // fallback is sane because this method is never used - // in construction of derivative caches. - if rc, err := c.Config.Trust.GetRemoteCluster(ctx, clusterName); err == nil { - return rc, nil - } - } - return rc, trace.Wrap(err) -} - -// ListRemoteClusters returns a page of remote clusters. -func (c *Cache) ListRemoteClusters(ctx context.Context, pageSize int, nextToken string) ([]types.RemoteCluster, string, error) { - _, span := c.Tracer.Start(ctx, "cache/ListRemoteClusters") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.remoteClusters) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - remoteClusters, token, err := rg.reader.ListRemoteClusters(ctx, pageSize, nextToken) - return remoteClusters, token, trace.Wrap(err) -} - // GetUser is a part of auth.Cache implementation. func (c *Cache) GetUser(ctx context.Context, name string, withSecrets bool) (types.User, error) { _, span := c.Tracer.Start(ctx, "cache/GetUser") diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index 32f1b5dd26b76..91e864f99b18c 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -1852,40 +1852,6 @@ func TestAuthServers(t *testing.T) { }) } -// TestRemoteClusters tests remote clusters caching -func TestRemoteClusters(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForProxy) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[types.RemoteCluster]{ - newResource: func(name string) (types.RemoteCluster, error) { - return types.NewRemoteCluster(name) - }, - create: func(ctx context.Context, rc types.RemoteCluster) error { - _, err := p.trustS.CreateRemoteCluster(ctx, rc) - return err - }, - list: func(ctx context.Context) ([]types.RemoteCluster, error) { - return p.trustS.GetRemoteClusters(ctx) - }, - cacheGet: func(ctx context.Context, name string) (types.RemoteCluster, error) { - return p.cache.GetRemoteCluster(ctx, name) - }, - cacheList: func(ctx context.Context) ([]types.RemoteCluster, error) { - return p.cache.GetRemoteClusters(ctx) - }, - update: func(ctx context.Context, rc types.RemoteCluster) error { - _, err := p.trustS.UpdateRemoteCluster(ctx, rc) - return err - }, - deleteAll: func(ctx context.Context) error { - return p.trustS.DeleteAllRemoteClusters(ctx) - }, - }) -} - // TestKubernetes tests that CRUD operations on kubernetes clusters resources are // replicated from the backend to the cache. func TestKubernetes(t *testing.T) { diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 81e559b94c13a..6eb240ae878c5 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -49,7 +49,8 @@ type collectionHandler interface { type collections struct { byKind map[resourceKind]collectionHandler - botInstances *collection[*machineidv1.BotInstance, botInstanceIndex] + botInstances *collection[*machineidv1.BotInstance, botInstanceIndex] + remoteClusters *collection[types.RemoteCluster, remoteClusterIndex] } // isKnownUncollectedKind is true if a resource kind is not stored in @@ -87,13 +88,20 @@ func setupCollections(c Config, legacyCollections map[resourceKind]legacyCollect out.botInstances = collect out.byKind[resourceKind] = out.botInstances + case types.KindRemoteCluster: + collect, err := newRemoteClusterCollection(c.Trust, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.remoteClusters = collect + out.byKind[resourceKind] = out.remoteClusters default: _, legacyOk := legacyCollections[resourceKind] if _, ok := out.byKind[resourceKind]; !ok && !legacyOk { return nil, trace.BadParameter("resource %q is not supported", watch.Kind) } } - } return out, nil diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index d8ddb27545228..be892ef2040fe 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -356,15 +356,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect watch: watch, } collections.byKind[resourceKind] = collections.tunnelConnections - case types.KindRemoteCluster: - if c.Presence == nil { - return nil, trace.BadParameter("missing parameter Presence") - } - collections.remoteClusters = &genericCollection[types.RemoteCluster, remoteClusterGetter, remoteClusterExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.remoteClusters case types.KindAppServer: if c.Presence == nil { return nil, trace.BadParameter("missing parameter Presence") diff --git a/lib/cache/remote_cluster.go b/lib/cache/remote_cluster.go new file mode 100644 index 0000000000000..d907df50365fe --- /dev/null +++ b/lib/cache/remote_cluster.go @@ -0,0 +1,166 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/clientutils" + "github.com/gravitational/teleport/lib/itertools/stream" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" +) + +type remoteClusterIndex string + +const remoteClusterNameIndex remoteClusterIndex = "name" + +func newRemoteClusterCollection(upstream services.Trust, w types.WatchKind) (*collection[types.RemoteCluster, remoteClusterIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter Trust") + } + + return &collection[types.RemoteCluster, remoteClusterIndex]{ + store: newStore( + types.KindRemoteCluster, + types.RemoteCluster.Clone, + map[remoteClusterIndex]func(types.RemoteCluster) string{ + remoteClusterNameIndex: types.RemoteCluster.GetName, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.RemoteCluster, error) { + out, err := stream.Collect(clientutils.Resources(ctx, upstream.ListRemoteClusters)) + return out, trace.Wrap(err) + }, + headerTransform: func(hdr *types.ResourceHeader) types.RemoteCluster { + return &types.RemoteClusterV3{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +type remoteClustersCacheKey struct { + name string +} + +// GetRemoteClusters returns a list of remote clusters +func (c *Cache) GetRemoteClusters(ctx context.Context) ([]types.RemoteCluster, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetRemoteClusters") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.remoteClusters) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + + if rg.ReadCache() { + remotes := make([]types.RemoteCluster, 0, rg.store.len()) + for rc := range rg.store.resources(remoteClusterNameIndex, "", "") { + remotes = append(remotes, rc.Clone()) + } + + return remotes, nil + } + + cachedRemotes, err := utils.FnCacheGet(ctx, c.fnCache, remoteClustersCacheKey{}, func(ctx context.Context) ([]types.RemoteCluster, error) { + var out []types.RemoteCluster + var startKey string + + for { + clusters, next, err := c.Config.Trust.ListRemoteClusters(ctx, 0, startKey) + if err != nil { + return nil, trace.Wrap(err) + } + + out = append(out, clusters...) + startKey = next + if next == "" { + break + } + } + + return out, nil + }) + if err != nil || cachedRemotes == nil { + return nil, trace.Wrap(err) + } + + remotes := make([]types.RemoteCluster, 0, len(cachedRemotes)) + for _, remote := range cachedRemotes { + remotes = append(remotes, remote.Clone()) + } + return remotes, nil +} + +// GetRemoteCluster returns a remote cluster by name +func (c *Cache) GetRemoteCluster(ctx context.Context, clusterName string) (types.RemoteCluster, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetRemoteCluster") + defer span.End() + + var upstreamRead bool + getter := genericGetter[types.RemoteCluster, remoteClusterIndex]{ + cache: c, + collection: c.collections.remoteClusters, + index: remoteClusterNameIndex, + upstreamGet: func(ctx context.Context, clusterName string) (types.RemoteCluster, error) { + upstreamRead = true + cachedRemote, err := utils.FnCacheGet(ctx, c.fnCache, remoteClustersCacheKey{clusterName}, func(ctx context.Context) (types.RemoteCluster, error) { + remote, err := c.Config.Trust.GetRemoteCluster(ctx, clusterName) + return remote, err + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return cachedRemote.Clone(), nil + }, + } + out, err := getter.get(ctx, clusterName) + if trace.IsNotFound(err) && !upstreamRead { + // fallback is sane because this method is never used + // in construction of derivative caches. + if rc, err := c.Config.Trust.GetRemoteCluster(ctx, clusterName); err == nil { + return rc, nil + } + } + return out, trace.Wrap(err) +} + +// ListRemoteClusters returns a page of remote clusters. +func (c *Cache) ListRemoteClusters(ctx context.Context, pageSize int, nextToken string) ([]types.RemoteCluster, string, error) { + _, span := c.Tracer.Start(ctx, "cache/ListRemoteClusters") + defer span.End() + + lister := genericLister[types.RemoteCluster, remoteClusterIndex]{ + cache: c, + collection: c.collections.remoteClusters, + index: remoteClusterNameIndex, + upstreamList: c.Config.Trust.ListRemoteClusters, + nextToken: types.RemoteCluster.GetName, + } + out, next, err := lister.list(ctx, pageSize, nextToken) + return out, next, trace.Wrap(err) +} diff --git a/lib/cache/remote_cluster_test.go b/lib/cache/remote_cluster_test.go new file mode 100644 index 0000000000000..3947e4bbf5a05 --- /dev/null +++ b/lib/cache/remote_cluster_test.go @@ -0,0 +1,126 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/clientutils" + "github.com/gravitational/teleport/lib/itertools/stream" +) + +// TestRemoteClusters tests remote clusters caching +func TestRemoteClusters(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForProxy) + t.Cleanup(p.Close) + + t.Run("GetRemoteClusters", func(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForProxy) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[types.RemoteCluster]{ + newResource: func(name string) (types.RemoteCluster, error) { + return types.NewRemoteCluster(name) + }, + create: func(ctx context.Context, rc types.RemoteCluster) error { + _, err := p.trustS.CreateRemoteCluster(ctx, rc) + return err + }, + list: func(ctx context.Context) ([]types.RemoteCluster, error) { + return p.trustS.GetRemoteClusters(ctx) + }, + cacheGet: func(ctx context.Context, name string) (types.RemoteCluster, error) { + return p.cache.GetRemoteCluster(ctx, name) + }, + cacheList: func(ctx context.Context) ([]types.RemoteCluster, error) { + return p.cache.GetRemoteClusters(ctx) + }, + update: func(ctx context.Context, rc types.RemoteCluster) error { + _, err := p.trustS.UpdateRemoteCluster(ctx, rc) + return err + }, + deleteAll: func(ctx context.Context) error { + return p.trustS.DeleteAllRemoteClusters(ctx) + }, + }) + }) + + t.Run("ListRemoteClusters", func(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForProxy) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[types.RemoteCluster]{ + newResource: func(name string) (types.RemoteCluster, error) { + return types.NewRemoteCluster(name) + }, + create: func(ctx context.Context, rc types.RemoteCluster) error { + _, err := p.trustS.CreateRemoteCluster(ctx, rc) + return err + }, + list: func(ctx context.Context) ([]types.RemoteCluster, error) { + return p.trustS.GetRemoteClusters(ctx) + }, + cacheGet: func(ctx context.Context, name string) (types.RemoteCluster, error) { + return p.cache.GetRemoteCluster(ctx, name) + }, + cacheList: func(ctx context.Context) ([]types.RemoteCluster, error) { + return stream.Collect(clientutils.Resources(ctx, p.cache.ListRemoteClusters)) + }, + update: func(ctx context.Context, rc types.RemoteCluster) error { + _, err := p.trustS.UpdateRemoteCluster(ctx, rc) + return err + }, + deleteAll: func(ctx context.Context) error { + return p.trustS.DeleteAllRemoteClusters(ctx) + }, + }) + + // TODO(smallinsky): Remove this once pagination tests covering this case for each resource type + // have been merged into v17. + t.Run("test cluster get/update", func(t *testing.T) { + item, err := types.NewRemoteCluster("test-cluster") + require.NoError(t, err) + + _, err = p.trustS.CreateRemoteCluster(context.Background(), item) + require.NoError(t, err) + + var itemFromCache types.RemoteCluster + require.EventuallyWithT(t, func(t *assert.CollectT) { + var err error + itemFromCache, err = p.cache.GetRemoteCluster(context.Background(), "test-cluster") + require.NoError(t, err) + }, 2*time.Second, time.Millisecond*40) + + itemFromCache.SetConnectionStatus(teleport.RemoteClusterStatusOffline) + _, err = p.trustS.UpdateRemoteCluster(context.Background(), itemFromCache) + require.NoError(t, err) + }) + }) +}