From 5dee0e007e340b9dfd5a063c34a2fa86e73d4c3c Mon Sep 17 00:00:00 2001 From: Forrest Marshall Date: Thu, 26 Oct 2023 16:27:42 +0000 Subject: [PATCH] improve test cov for auth caches --- lib/auth/accesspoint/accesspoint.go | 155 +++++++++++++++++++++++++++ lib/auth/auth_with_roles_test.go | 160 +++++++++++++++++++++------- lib/auth/helpers.go | 54 ++++++++++ lib/auth/tls_test.go | 23 +++- lib/service/service.go | 114 +++----------------- 5 files changed, 365 insertions(+), 141 deletions(-) create mode 100644 lib/auth/accesspoint/accesspoint.go diff --git a/lib/auth/accesspoint/accesspoint.go b/lib/auth/accesspoint/accesspoint.go new file mode 100644 index 0000000000000..c1fffcda9dc7a --- /dev/null +++ b/lib/auth/accesspoint/accesspoint.go @@ -0,0 +1,155 @@ +/* +Copyright 2015-2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// package accesspoint provides helpers for configuring caches in the context of +// setting up service-level auth access points. this logic has been moved out of +// lib/service in order to facilitate better testing practices. 
+package accesspoint + +import ( + "context" + "slices" + "time" + + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/cache" + "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/services" +) + +// AccessCacheConfig holds parameters used to configure a cache to +// serve as an auth access point for a teleport service. +type AccessCacheConfig struct { + // Context is the base context used to propagate closure to + // cache components. + Context context.Context + // Services is a collection of upstream services from which + // the access cache will derive its state. + Services services.Services + // Setup is a function that takes cache configuration and + // modifies it to support a specific teleport service. + Setup cache.SetupConfigFn + // CacheName identifies the cache in logs. + CacheName []string + // Events is true if cache should have the events system enabled. + Events bool + // Unstarted is true if the cache should not be started. + Unstarted bool + // MaxRetryPeriod is the max retry period between connection attempts + // to auth. + MaxRetryPeriod time.Duration + // ProcessID is an optional identifier used to help disambiguate logs + // when teleport performs in-memory reloads. + ProcessID string + // TracingProvider is the provider to be used for exporting + // traces. No-op tracers will be used if no provider is set. 
+ TracingProvider *tracing.Provider +} + +func (c *AccessCacheConfig) CheckAndSetDefaults() error { + if c.Services == nil { + return trace.BadParameter("missing parameter Services") + } + if c.Setup == nil { + return trace.BadParameter("missing parameter Setup") + } + if len(c.CacheName) == 0 { + return trace.BadParameter("missing parameter CacheName") + } + if c.Context == nil { + c.Context = context.Background() + } + return nil +} + +// NewAccessCache builds a cache.Cache instance for a teleport service. This logic has been +// broken out of lib/service in order to support easier unit testing of process components. +func NewAccessCache(cfg AccessCacheConfig) (*cache.Cache, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + log.Debugf("Creating in-memory backend for %v.", cfg.CacheName) + mem, err := memory.New(memory.Config{ + Context: cfg.Context, + EventsOff: !cfg.Events, + Mirror: true, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var tracer oteltrace.Tracer + if cfg.TracingProvider != nil { + tracer = cfg.TracingProvider.Tracer(teleport.ComponentCache) + } + reporter, err := backend.NewReporter(backend.ReporterConfig{ + Component: teleport.ComponentCache, + Backend: mem, + Tracer: tracer, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + component := slices.Clone(cfg.CacheName) + if cfg.ProcessID != "" { + component = append(component, cfg.ProcessID) + } + + component = append(component, teleport.ComponentCache) + metricComponent := append(slices.Clone(cfg.CacheName), teleport.ComponentCache) + + return cache.New(cfg.Setup(cache.Config{ + Context: cfg.Context, + Backend: reporter, + Events: cfg.Services, + ClusterConfig: cfg.Services, + Provisioner: cfg.Services, + Trust: cfg.Services, + Users: cfg.Services, + Access: cfg.Services, + DynamicAccess: cfg.Services, + Presence: cfg.Services, + Restrictions: cfg.Services, + Apps: cfg.Services, + Kubernetes: cfg.Services, + DatabaseServices: 
cfg.Services, + Databases: cfg.Services, + AppSession: cfg.Services, + SnowflakeSession: cfg.Services, + SAMLIdPSession: cfg.Services, + WindowsDesktops: cfg.Services, + SAMLIdPServiceProviders: cfg.Services, + UserGroups: cfg.Services, + Okta: cfg.Services.OktaClient(), + SecReports: cfg.Services.SecReportsClient(), + UserLoginStates: cfg.Services.UserLoginStateClient(), + Integrations: cfg.Services, + DiscoveryConfigs: cfg.Services.DiscoveryConfigClient(), + WebSession: cfg.Services.WebSessions(), + WebToken: cfg.Services.WebTokens(), + Component: teleport.Component(component...), + MetricComponent: teleport.Component(metricComponent...), + Tracer: tracer, + MaxRetryPeriod: cfg.MaxRetryPeriod, + Unstarted: cfg.Unstarted, + })) +} diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 0ac247bfb3dab..d3ebe57fcee17 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -4255,7 +4255,12 @@ func TestListResources_WithRoles(t *testing.T) { func TestListUnifiedResources_KindsFilter(t *testing.T) { t.Parallel() ctx := context.Background() - srv := newTestTLSServer(t) + srv := newTestTLSServer(t, withCacheEnabled(true)) + + require.Eventually(t, func() bool { + return srv.Auth().UnifiedResourceCache.IsInitialized() + }, 5*time.Second, 200*time.Millisecond, "unified resource watcher never initialized") + for i := 0; i < 5; i++ { name := uuid.New().String() node, err := types.NewServerWithLabels( @@ -4293,15 +4298,19 @@ func TestListUnifiedResources_KindsFilter(t *testing.T) { require.NoError(t, err) clt, err := srv.NewClient(TestUser(user.GetName())) require.NoError(t, err) - resp, err := clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ - Kinds: []string{types.KindDatabase}, - Limit: 5, - SortBy: types.SortBy{IsDesc: true, Field: types.ResourceMetadataName}, - }) - require.NoError(t, err) - require.Eventually(t, func() bool { + + var resp *proto.ListUnifiedResourcesResponse + 
inlineEventually(t, func() bool { + var err error + resp, err = clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ + Kinds: []string{types.KindDatabase}, + Limit: 5, + SortBy: types.SortBy{IsDesc: true, Field: types.ResourceMetadataName}, + }) + require.NoError(t, err) return len(resp.Resources) == 5 }, time.Second, time.Second/10) + // Check that all resources are of type KindDatabaseServer for _, resource := range resp.Resources { r := resource.GetDatabaseServer() @@ -4312,7 +4321,12 @@ func TestListUnifiedResources_KindsFilter(t *testing.T) { func TestListUnifiedResources_WithPinnedResources(t *testing.T) { t.Parallel() ctx := context.Background() - srv := newTestTLSServer(t) + srv := newTestTLSServer(t, withCacheEnabled(true)) + + require.Eventually(t, func() bool { + return srv.Auth().UnifiedResourceCache.IsInitialized() + }, 5*time.Second, 200*time.Millisecond, "unified resource watcher never initialized") + names := []string{"tifa", "cloud", "aerith", "baret", "cid", "tifa2"} for _, name := range names { @@ -4355,11 +4369,17 @@ func TestListUnifiedResources_WithPinnedResources(t *testing.T) { clt, err := srv.NewClient(identity) require.NoError(t, err) - resp, err := clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ - PinnedOnly: true, - }) - require.NoError(t, err) - require.Len(t, resp.Resources, 1) + + var resp *proto.ListUnifiedResourcesResponse + inlineEventually(t, func() bool { + var err error + resp, err = clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ + PinnedOnly: true, + }) + require.NoError(t, err) + return len(resp.Resources) == 1 + }, time.Second*5, time.Millisecond*200) + require.Empty(t, resp.NextKey) // Check that our returned resource is the pinned resource require.Equal(t, "tifa", resp.Resources[0].GetNode().GetHostname()) @@ -4370,7 +4390,12 @@ func TestListUnifiedResources_WithPinnedResources(t *testing.T) { func TestListUnifiedResources_WithSearch(t *testing.T) { t.Parallel() ctx := 
context.Background() - srv := newTestTLSServer(t) + srv := newTestTLSServer(t, withCacheEnabled(true)) + + require.Eventually(t, func() bool { + return srv.Auth().UnifiedResourceCache.IsInitialized() + }, 5*time.Second, 200*time.Millisecond, "unified resource watcher never initialized") + names := []string{"vivi", "cloud", "aerith", "barret", "cid", "vivi2"} for i := 0; i < 6; i++ { name := names[i] @@ -4387,9 +4412,12 @@ func TestListUnifiedResources_WithSearch(t *testing.T) { _, err = srv.Auth().UpsertNode(ctx, node) require.NoError(t, err) } - testNodes, err := srv.Auth().GetNodes(ctx, apidefaults.Namespace) - require.NoError(t, err) - require.Len(t, testNodes, 6) + + inlineEventually(t, func() bool { + testNodes, err := srv.Auth().GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) + return len(testNodes) == 6 + }, time.Second*5, time.Millisecond*200) sp := &types.SAMLIdPServiceProviderV1{ ResourceHeader: types.ResourceHeader{ @@ -4409,13 +4437,19 @@ func TestListUnifiedResources_WithSearch(t *testing.T) { require.NoError(t, err) clt, err := srv.NewClient(TestUser(user.GetName())) require.NoError(t, err) - resp, err := clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ - SearchKeywords: []string{"tifa"}, - Limit: 10, - SortBy: types.SortBy{IsDesc: true, Field: types.ResourceMetadataName}, - }) - require.NoError(t, err) - require.Len(t, resp.Resources, 1) + + var resp *proto.ListUnifiedResourcesResponse + inlineEventually(t, func() bool { + var err error + resp, err = clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ + SearchKeywords: []string{"tifa"}, + Limit: 10, + SortBy: types.SortBy{IsDesc: true, Field: types.ResourceMetadataName}, + }) + require.NoError(t, err) + return len(resp.Resources) == 1 + }, time.Second*5, time.Millisecond*200) + require.Empty(t, resp.NextKey) // Check that our returned resource has the correct name @@ -4430,7 +4464,12 @@ func TestListUnifiedResources_WithSearch(t *testing.T) { func 
TestListUnifiedResources_MixedAccess(t *testing.T) { t.Parallel() ctx := context.Background() - srv := newTestTLSServer(t) + srv := newTestTLSServer(t, withCacheEnabled(true)) + + require.Eventually(t, func() bool { + return srv.Auth().UnifiedResourceCache.IsInitialized() + }, 5*time.Second, 200*time.Millisecond, "unified resource watcher never initialized") + names := []string{"tifa", "cloud", "aerith", "baret", "cid", "tifa2"} for i := 0; i < 6; i++ { name := names[i] @@ -4475,17 +4514,24 @@ func TestListUnifiedResources_MixedAccess(t *testing.T) { require.NoError(t, err) require.NoError(t, srv.Auth().UpsertWindowsDesktop(ctx, desktop)) } - testNodes, err := srv.Auth().GetNodes(ctx, apidefaults.Namespace) - require.NoError(t, err) - require.Len(t, testNodes, 6) - testDbs, err := srv.Auth().GetDatabaseServers(ctx, apidefaults.Namespace) - require.NoError(t, err) - require.Len(t, testDbs, 6) + inlineEventually(t, func() bool { + testNodes, err := srv.Auth().GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) + return len(testNodes) == 6 + }, time.Second*5, time.Millisecond*200) - testDesktops, err := srv.Auth().GetWindowsDesktops(ctx, types.WindowsDesktopFilter{}) - require.NoError(t, err) - require.Len(t, testDesktops, 6) + inlineEventually(t, func() bool { + testDbs, err := srv.Auth().GetDatabaseServers(ctx, apidefaults.Namespace) + require.NoError(t, err) + return len(testDbs) == 6 + }, time.Second*5, time.Millisecond*200) + + inlineEventually(t, func() bool { + testDesktops, err := srv.Auth().GetWindowsDesktops(ctx, types.WindowsDesktopFilter{}) + require.NoError(t, err) + return len(testDesktops) == 6 + }, time.Second*5, time.Millisecond*200) // create user, role, and client username := "user" @@ -4512,7 +4558,9 @@ func TestListUnifiedResources_MixedAccess(t *testing.T) { clt, err := srv.NewClient(identity) require.NoError(t, err) - require.NoError(t, err) + // ensure updated roles have propagated to auth cache + flushCache(t, srv.Auth()) + resp, 
err := clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ Limit: 20, SortBy: types.SortBy{IsDesc: true, Field: types.ResourceMetadataName}, @@ -4533,7 +4581,12 @@ func TestListUnifiedResources_MixedAccess(t *testing.T) { func TestListUnifiedResources_WithPredicate(t *testing.T) { t.Parallel() ctx := context.Background() - srv := newTestTLSServer(t) + srv := newTestTLSServer(t, withCacheEnabled(true)) + + require.Eventually(t, func() bool { + return srv.Auth().UnifiedResourceCache.IsInitialized() + }, 5*time.Second, 200*time.Millisecond, "unified resource watcher never initialized") + names := []string{"tifa", "cloud", "aerith", "baret", "cid", "tifa2"} for i := 0; i < 6; i++ { name := names[i] @@ -4552,9 +4605,12 @@ func TestListUnifiedResources_WithPredicate(t *testing.T) { _, err = srv.Auth().UpsertNode(ctx, node) require.NoError(t, err) } - testNodes, err := srv.Auth().GetNodes(ctx, apidefaults.Namespace) - require.NoError(t, err) - require.Len(t, testNodes, 6) + + inlineEventually(t, func() bool { + testNodes, err := srv.Auth().GetNodes(ctx, apidefaults.Namespace) + require.NoError(t, err) + return len(testNodes) == 6 + }, time.Second*5, time.Millisecond*200) // create user, role, and client username := "theuser" @@ -4564,7 +4620,6 @@ func TestListUnifiedResources_WithPredicate(t *testing.T) { clt, err := srv.NewClient(identity) require.NoError(t, err) - require.NoError(t, err) resp, err := clt.ListUnifiedResources(ctx, &proto.ListUnifiedResourcesRequest{ PredicateExpression: `labels.name == "tifa"`, Limit: 10, @@ -6883,3 +6938,26 @@ func TestKubeKeepAliveServer(t *testing.T) { ) } } + +// inlineEventually is equivalent to require.Eventually except that it runs the provided function directly +// instead of in a background goroutine, making it safe to fail the test from within the closure. 
+func inlineEventually(t *testing.T, cond func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { + t.Helper() + + timer := time.NewTimer(waitFor) + defer timer.Stop() + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for { + select { + case <-timer.C: + require.FailNow(t, "condition never satisfied", msgAndArgs...) + case <-ticker.C: + if cond() { + return + } + } + } +} diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index ee1e3392e5422..64e5179a6947e 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -21,9 +21,11 @@ import ( "crypto/tls" "crypto/x509" "net" + "strings" "testing" "time" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -39,12 +41,14 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/ai" "github.com/gravitational/teleport/lib/ai/embedding" + "github.com/gravitational/teleport/lib/auth/accesspoint" "github.com/gravitational/teleport/lib/auth/keystore" "github.com/gravitational/teleport/lib/auth/native" authority "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/cache" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" @@ -82,6 +86,8 @@ type TestAuthServerConfig struct { AuthPreferenceSpec *types.AuthPreferenceSpecV2 // Embedder is required to enable the assist in the auth server. Embedder embedding.Embedder + // CacheEnabled enables the primary auth server cache. 
+ CacheEnabled bool } // CheckAndSetDefaults checks and sets defaults @@ -294,6 +300,19 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { return nil, trace.Wrap(err) } + if cfg.CacheEnabled { + srv.AuthServer.Cache, err = accesspoint.NewAccessCache(accesspoint.AccessCacheConfig{ + Context: ctx, + Services: srv.AuthServer.Services, + Setup: cache.ForAuth, + CacheName: []string{teleport.ComponentAuth}, + Events: true, + }) + if err != nil { + return nil, trace.Wrap(err) + } + } + err = srv.AuthServer.SetClusterAuditConfig(ctx, types.DefaultClusterAuditConfig()) if err != nil { return nil, trace.Wrap(err) } @@ -737,6 +756,7 @@ func NewTestTLSServer(cfg TestTLSServerConfig) (*TestTLSServer, error) { if err != nil { return nil, trace.Wrap(err) } + // Register TLS endpoint of the auth service tlsConfig, err := srv.Identity.TLSConfig(srv.AuthServer.CipherSuites) if err != nil { return nil, trace.Wrap(err) } @@ -1230,3 +1250,37 @@ type noopEmbedder struct{} func (n noopEmbedder) ComputeEmbeddings(_ context.Context, _ []string) ([]embedding.Vector64, error) { return []embedding.Vector64{}, nil } + +// flushClt is the set of methods expected by the flushCache helper. +type flushClt interface { + // GetNamespace returns namespace by name + GetNamespace(name string) (*types.Namespace, error) + // UpsertNamespace upserts namespace + UpsertNamespace(types.Namespace) error + // DeleteNamespace deletes namespace by name + DeleteNamespace(name string) error +} + +// flushCache is a helper for waiting until preceding changes have propagated to the +// cache during a test. this is useful for writing tests that may want to update backend +// state and then perform some operation that depends on the auth server knowing that state. 
+// note that this is only intended for use with the memory backend, as this helper relies on the assumption that +// write events for different keys show up in the order in which the writes were performed, which +// is not necessarily true for all backends. +func flushCache(t *testing.T, clt flushClt) { + // the pattern of writing a resource and then waiting for it to appear + // works for any resource type (when using memory backend). we use namespaces + // here because namespaces are deprecated and therefore unlikely to interfere + // with tests. + name := strings.ReplaceAll(uuid.NewString(), "-", "") + defer clt.DeleteNamespace(name) + + ns, err := types.NewNamespace(name) + require.NoError(t, err) + + require.NoError(t, clt.UpsertNamespace(ns)) + require.Eventually(t, func() bool { + _, err := clt.GetNamespace(name) + return err == nil + }, time.Second*20, time.Millisecond*200) +} diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 5e92bd3d89f21..ccaf5edc7c446 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -4587,15 +4587,32 @@ func verifyJWTAWSOIDC(clock clockwork.Clock, clusterName string, pairs []*types. return nil, trace.NewAggregate(errs...) } +type testTLSServerOptions struct { + cacheEnabled bool +} + +type testTLSServerOption func(*testTLSServerOptions) + +func withCacheEnabled(enabled bool) testTLSServerOption { + return func(options *testTLSServerOptions) { + options.cacheEnabled = enabled + } +} + // newTestTLSServer is a helper that returns a *TestTLSServer with sensible // defaults for most tests that are exercising Auth Service RPCs. // // For more advanced use-cases, call NewTestAuthServer and NewTestTLSServer // to provide a more detailed configuration. 
-func newTestTLSServer(t testing.TB) *TestTLSServer { +func newTestTLSServer(t testing.TB, opts ...testTLSServerOption) *TestTLSServer { + var options testTLSServerOptions + for _, opt := range opts { + opt(&options) + } as, err := NewTestAuthServer(TestAuthServerConfig{ - Dir: t.TempDir(), - Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()), + Dir: t.TempDir(), + Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()), + CacheEnabled: options.cacheEnabled, }) require.NoError(t, err) diff --git a/lib/service/service.go b/lib/service/service.go index c716e9c391a11..cf8635bacd442 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -76,6 +76,7 @@ import ( "github.com/gravitational/teleport/lib/ai/embedding" "github.com/gravitational/teleport/lib/auditd" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/auth/accesspoint" "github.com/gravitational/teleport/lib/auth/keygen" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/authz" @@ -85,7 +86,6 @@ import ( "github.com/gravitational/teleport/lib/backend/firestore" "github.com/gravitational/teleport/lib/backend/kubernetes" "github.com/gravitational/teleport/lib/backend/lite" - "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/backend/pgbk" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/cache" @@ -1780,12 +1780,12 @@ func (process *TeleportProcess) initAuthService() error { return nil } - cache, err := process.newAccessCache(accessCacheConfig{ - services: as.Services, - setup: cache.ForAuth, - cacheName: []string{teleport.ComponentAuth}, - events: true, - unstarted: true, + cache, err := process.newAccessCache(accesspoint.AccessCacheConfig{ + Services: as.Services, + Setup: cache.ForAuth, + CacheName: []string{teleport.ComponentAuth}, + Events: true, + Unstarted: true, }) if err != nil { return trace.Wrap(err) @@ -2147,94 
+2147,14 @@ func (process *TeleportProcess) OnExit(serviceName string, callback func(interfa }) } -// accessCacheConfig contains -// configuration for access cache -type accessCacheConfig struct { - // services is a collection - // of services to use as a cache base - services services.Services - // setup is a function that takes - // cache configuration and modifies it - setup cache.SetupConfigFn - // cacheName is a cache name - cacheName []string - // events is true if cache should turn on events - events bool - // unstarted is true if the cache should not be started - unstarted bool -} - -func (c *accessCacheConfig) CheckAndSetDefaults() error { - if c.services == nil { - return trace.BadParameter("missing parameter services") - } - if c.setup == nil { - return trace.BadParameter("missing parameter setup") - } - if len(c.cacheName) == 0 { - return trace.BadParameter("missing parameter cacheName") - } - return nil -} - // newAccessCache returns new local cache access point -func (process *TeleportProcess) newAccessCache(cfg accessCacheConfig) (*cache.Cache, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - process.log.Debugf("Creating in-memory backend for %v.", cfg.cacheName) - mem, err := memory.New(memory.Config{ - Context: process.ExitContext(), - EventsOff: !cfg.events, - Mirror: true, - }) - if err != nil { - return nil, trace.Wrap(err) - } - reporter, err := backend.NewReporter(backend.ReporterConfig{ - Component: teleport.ComponentCache, - Backend: mem, - Tracer: process.TracingProvider.Tracer(teleport.ComponentCache), - }) - if err != nil { - return nil, trace.Wrap(err) - } +func (process *TeleportProcess) newAccessCache(cfg accesspoint.AccessCacheConfig) (*cache.Cache, error) { + cfg.Context = process.ExitContext() + cfg.ProcessID = process.id + cfg.TracingProvider = process.TracingProvider + cfg.MaxRetryPeriod = process.Config.CachePolicy.MaxRetryPeriod - return cache.New(cfg.setup(cache.Config{ - Context: 
process.ExitContext(), - Backend: reporter, - Events: cfg.services, - ClusterConfig: cfg.services, - Provisioner: cfg.services, - Trust: cfg.services, - Users: cfg.services, - Access: cfg.services, - DynamicAccess: cfg.services, - Presence: cfg.services, - Restrictions: cfg.services, - Apps: cfg.services, - Kubernetes: cfg.services, - DatabaseServices: cfg.services, - Databases: cfg.services, - AppSession: cfg.services, - SnowflakeSession: cfg.services, - SAMLIdPSession: cfg.services, - WindowsDesktops: cfg.services, - SAMLIdPServiceProviders: cfg.services, - UserGroups: cfg.services, - Okta: cfg.services.OktaClient(), - SecReports: cfg.services.SecReportsClient(), - UserLoginStates: cfg.services.UserLoginStateClient(), - Integrations: cfg.services, - DiscoveryConfigs: cfg.services.DiscoveryConfigClient(), - WebSession: cfg.services.WebSessions(), - WebToken: cfg.services.WebTokens(), - Component: teleport.Component(append(cfg.cacheName, process.id, teleport.ComponentCache)...), - MetricComponent: teleport.Component(append(cfg.cacheName, teleport.ComponentCache)...), - Tracer: process.TracingProvider.Tracer(teleport.ComponentCache), - MaxRetryPeriod: process.Config.CachePolicy.MaxRetryPeriod, - Unstarted: cfg.unstarted, - })) + return accesspoint.NewAccessCache(cfg) } // newLocalCacheForNode returns new instance of access point configured for a local proxy. @@ -2386,10 +2306,10 @@ func (process *TeleportProcess) newLocalCacheForWindowsDesktop(clt auth.ClientI, // NewLocalCache returns new instance of access point func (process *TeleportProcess) NewLocalCache(clt auth.ClientI, setupConfig cache.SetupConfigFn, cacheName []string) (*cache.Cache, error) { - return process.newAccessCache(accessCacheConfig{ - services: clt, - setup: setupConfig, - cacheName: cacheName, + return process.newAccessCache(accesspoint.AccessCacheConfig{ + Services: clt, + Setup: setupConfig, + CacheName: cacheName, }) }