diff --git a/api/client/secreport/crud.go b/api/client/secreport/crud.go index a3f7ae28f2f06..e0f1bed87bc9c 100644 --- a/api/client/secreport/crud.go +++ b/api/client/secreport/crud.go @@ -115,36 +115,20 @@ func (c *Client) GetSecurityAuditQueryResult(ctx context.Context, resultID, next return resp, nil } -// GetSecurityReportsStates returns all security reports states. -func (c *Client) GetSecurityReportsStates(ctx context.Context) ([]*secreports.ReportState, error) { - return nil, trace.NotImplemented("GetSecurityReportsStates is not supported in the gRPC client") -} - // UpsertSecurityReportsState upserts security reports state. func (c *Client) UpsertSecurityReportsState(ctx context.Context, item *secreports.ReportState) error { return trace.NotImplemented("UpsertSecurityReportsState is not supported in the gRPC client") } -// DeleteSecurityReportsState deletes security reports state by name. -func (c *Client) DeleteSecurityReportsState(ctx context.Context, name string) error { - return trace.NotImplemented("DeleteSecurityReportsState is not supported in the gRPC client") -} - -// DeleteAllSecurityReportsStates deletes all security reports states. -func (c *Client) DeleteAllSecurityReportsStates(ctx context.Context) error { - return trace.NotImplemented("DeleteAllSecurityReportsStates is not supported in the gRPC client") -} - -// DeleteAllSecurityReports deletes all security reports. -func (c *Client) DeleteAllSecurityReports(ctx context.Context) error { - return trace.NotImplemented("DeleteAllSecurityReportsStates is not supported in the gRPC client") -} - -// DeleteAllSecurityAuditQueries deletes all security audit queries. -func (c *Client) DeleteAllSecurityAuditQueries(ctx context.Context) error { - return trace.NotImplemented("DeleteAllSecurityAuditQueries is not supported in the gRPC client") -} - func (c *Client) GetSecurityReportState(ctx context.Context, name string) (*secreports.ReportState, error) { - return nil, trace.NotImplemented("GetSecurityReportState is not supported in the gRPC client") + resp, err := c.grpcClient.GetReportState(ctx, &pb.GetReportStateRequest{Name: name}) + if err != nil { + return nil, trace.Wrap(err) + } + + out, err := v1.FromProtoReportState(resp) + if err != nil { + return nil, trace.Wrap(err) + } + return out, nil } diff --git a/api/types/header/header.go b/api/types/header/header.go index 37b3fb07be67d..7b8b6dbf19960 100644 --- a/api/types/header/header.go +++ b/api/types/header/header.go @@ -17,6 +17,7 @@ limitations under the License. package header import ( + "maps" "slices" "time" @@ -50,6 +51,19 @@ type ResourceHeader struct { Metadata Metadata `json:"metadata,omitempty"` } +func (h *ResourceHeader) Clone() *ResourceHeader { + if h == nil { + return nil + } + + return &ResourceHeader{ + Kind: h.Kind, + SubKind: h.SubKind, + Version: h.Version, + Metadata: *h.Metadata.Clone(), + } +} + // GetVersion returns the resource version. func (h *ResourceHeader) GetVersion() string { return h.Version @@ -179,6 +193,20 @@ type Metadata struct { Revision string `json:"revision,omitempty" yaml:"revision,omitempty"` } +func (m *Metadata) Clone() *Metadata { + if m == nil { + return nil + } + + return &Metadata{ + Name: m.Name, + Description: m.Description, + Labels: maps.Clone(m.Labels), + Expires: m.Expires, + Revision: m.Revision, + } +} + // GetRevision returns the revision func (m *Metadata) GetRevision() string { return m.Revision diff --git a/api/types/secreports/secreports.go b/api/types/secreports/secreports.go index e391be6804878..82be4fd64c431 100644 --- a/api/types/secreports/secreports.go +++ b/api/types/secreports/secreports.go @@ -35,6 +35,16 @@ type Report struct { Spec ReportSpec `json:"spec" yaml:"spec"` } +func (a *Report) Clone() *Report { + if a == nil { + return nil + } + return &Report{ + ResourceHeader: *a.ResourceHeader.Clone(), + Spec: *a.Spec.Clone(), + } +} + // ReportSpec is the security report spec. type ReportSpec struct { // Name is the Report name. @@ -49,6 +59,26 @@ type ReportSpec struct { Version string `json:"version,omitempty" yaml:"version,omitempty"` } +func (s *ReportSpec) Clone() *ReportSpec { + if s == nil { + return nil + } + var auditQueries []*AuditQuerySpec + if s.AuditQueries != nil { + auditQueries = make([]*AuditQuerySpec, 0, len(s.AuditQueries)) + for _, auditQuery := range s.AuditQueries { + auditQueries = append(auditQueries, auditQuery.Clone()) + } + } + return &ReportSpec{ + Name: s.Name, + Title: s.Title, + Description: s.Description, + AuditQueries: auditQueries, + Version: s.Version, + } +} + // AuditQuery is the audit query resource. type AuditQuery struct { // ResourceHeader is the resource header. @@ -57,6 +87,16 @@ type AuditQuery struct { Spec AuditQuerySpec `json:"spec" yaml:"spec"` } +func (a *AuditQuery) Clone() *AuditQuery { + if a == nil { + return nil + } + return &AuditQuery{ + ResourceHeader: *a.ResourceHeader.Clone(), + Spec: *a.Spec.Clone(), + } +} + // AuditQuerySpec is the audit query specification. type AuditQuerySpec struct { // Name is the AuditQuery name. @@ -69,6 +109,18 @@ type AuditQuerySpec struct { Query string `json:"query,omitempty" yaml:"query,omitempty"` } +func (s *AuditQuerySpec) Clone() *AuditQuerySpec { + if s == nil { + return nil + } + return &AuditQuerySpec{ + Name: s.Name, + Title: s.Title, + Description: s.Description, + Query: s.Query, + } +} + // CheckAndSetDefaults validates fields and populates empty fields with default values. func (a *AuditQuery) CheckAndSetDefaults() error { a.SetKind(types.KindAuditQuery) @@ -147,6 +199,17 @@ type ReportState struct { Spec ReportStateSpec `json:"spec,omitempty" yaml:"spec,omitempty"` } +func (a *ReportState) Clone() *ReportState { + if a == nil { + return nil + } + + return &ReportState{ + ResourceHeader: *a.ResourceHeader.Clone(), + Spec: *a.Spec.Clone(), + } +} + // ReportStateSpec is the security report state specification. type ReportStateSpec struct { // Name is the Report name. @@ -155,6 +218,16 @@ type ReportStateSpec struct { UpdatedAt time.Time `json:"updated_at,omitempty" yaml:"updated_at,omitempty"` } +func (s *ReportStateSpec) Clone() *ReportStateSpec { + if s == nil { + return nil + } + return &ReportStateSpec{ + Status: s.Status, + UpdatedAt: s.UpdatedAt, + } +} + // GetMetadata returns metadata. This is specifically for conforming to the Resource interface, // and should be removed when possible. func (a *ReportState) GetMetadata() types.Metadata { diff --git a/lib/cache/cache.go b/lib/cache/cache.go index fafb42b8dad5e..df842bab8e180 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -41,7 +41,6 @@ import ( "github.com/gravitational/teleport/api/internalutils/stream" apitracing "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/secreports" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/backend" @@ -497,9 +496,6 @@ type Cache struct { // cancel triggers exit context closure cancel context.CancelFunc - // legacyCacheCollections is a registry of resource legacyCollections - legacyCacheCollections *legacyCollections - // collections is a registry of resource collections. collections *collections @@ -511,7 +507,6 @@ type Cache struct { // regularly called methods. fnCache *utils.FnCache - secReportsCache services.SecReports eventsFanout *services.FanoutV2 lowVolumeEventsFanout *utils.RoundRobin[*services.FanoutV2] @@ -547,15 +542,6 @@ func (c *Cache) setReadStatus(ok bool, confirmedKinds map[resourceKind]types.Wat c.confirmedKinds = confirmedKinds } -// readLegacyCollectionCache acquires the cache read lock and uses getReader() to select the appropriate target for read -// operations on resources of the specified collection. The returned guard *must* be released to prevent deadlocks. -func readLegacyCollectionCache[R any](cache *Cache, collection collectionReader[R]) (legacyReadGuard[R], error) { - if collection == nil { - return legacyReadGuard[R]{}, trace.BadParameter("cannot read from an uninitialized cache collection") - } - return legacyReadCache(cache, collection.watchKind(), collection.getReader) -} - // acquireReadGuard provides a readGuard that may be used to determine how // a cache read should operate. The returned guard *must* be released to prevent deadlocks. func acquireReadGuard[T any, I comparable](cache *Cache, c *collection[T, I]) (readGuard[T, I], error) { @@ -580,55 +566,8 @@ func acquireReadGuard[T any, I comparable](cache *Cache, c *collection[T, I]) (r }, nil } -// legacyReadCache acquires the cache read lock and uses getReader() to select the appropriate target for read operations -// on resources of the specified kind. The returned guard *must* be released to prevent deadlocks. -func legacyReadCache[R any](cache *Cache, kind types.WatchKind, getReader func(cacheOK bool) R) (legacyReadGuard[R], error) { - if cache.closed.Load() { - return legacyReadGuard[R]{}, trace.Errorf("cache is closed") - } - cache.rw.RLock() - - if cache.ok { - if _, kindOK := cache.confirmedKinds[resourceKind{kind: kind.Kind, subkind: kind.SubKind}]; kindOK { - return legacyReadGuard[R]{ - reader: getReader(true), - release: cache.rw.RUnlock, - }, nil - } - } - - cache.rw.RUnlock() - return legacyReadGuard[R]{ - reader: getReader(false), - release: nil, - }, nil -} - -// legacyReadGuard holds a reference to a read-only "backend" R. If the referenced backed is the cache, then legacyReadGuard -// also holds the release function for the read lock, and ensures that it is not double-called. -type legacyReadGuard[R any] struct { - reader R - once sync.Once - release func() -} - -// Release releases the read lock if it is held. This method -// can be called multiple times. -func (r *legacyReadGuard[R]) Release() { - r.once.Do(func() { - if r.release == nil { - return - } - - r.release() - }) -} - -// IsCacheRead checks if this readGuard holds a cache reference. -func (r *legacyReadGuard[R]) IsCacheRead() bool { - return r.release != nil -} - +// readGuard holds a reference to a read-only "collection" T. If the referenced resource is in the cache, +// then readGuard also holds the release function for the read lock, and ensures that it is not double-called. type readGuard[T any, I comparable] struct { cacheRead bool store *store[T, I] @@ -636,6 +575,7 @@ type readGuard[T any, I comparable] struct { release func() } +// ReadCache checks if this readGuard holds a cache reference. func (r *readGuard[T, I]) ReadCache() bool { return r.cacheRead } @@ -906,26 +846,26 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - secReportsCache, err := local.NewSecReportsService(config.Backend, config.Clock) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - fanout := services.NewFanoutV2(services.FanoutV2Config{}) lowVolumeFanouts := make([]*services.FanoutV2, 0, config.FanoutShards) for i := 0; i < config.FanoutShards; i++ { lowVolumeFanouts = append(lowVolumeFanouts, services.NewFanoutV2(services.FanoutV2Config{})) } + collections, err := setupCollections(config) + if err != nil { + cancel() + return nil, trace.Wrap(err) + } + cs := &Cache{ ctx: ctx, cancel: cancel, Config: config, initC: make(chan struct{}), fnCache: fnCache, - secReportsCache: secReportsCache, eventsFanout: fanout, + collections: collections, lowVolumeEventsFanout: utils.NewRoundRobin(lowVolumeFanouts), Logger: slog.With( teleport.ComponentKey, config.Component, @@ -933,20 +873,6 @@ func New(config Config) (*Cache, error) { ), } - legacyCollections, err := setupLegacyCollections(cs, config.Watches) - if err != nil { - cs.Close() - return nil, trace.Wrap(err) - } - cs.legacyCacheCollections = legacyCollections - - collections, err := setupCollections(config, legacyCollections.byKind) - if err != nil { - cs.Close() - return nil, trace.Wrap(err) - } - cs.collections = collections - if config.Unstarted { return cs, nil } @@ -1561,33 +1487,8 @@ func (c *Cache) fetch(ctx context.Context, confirmedKinds map[resourceKind]types g, ctx := errgroup.WithContext(ctx) g.SetLimit(fetchLimit(c.target)) - applyfns := make([]applyFn, len(c.legacyCacheCollections.byKind)+len(c.collections.byKind)) + applyfns := make([]applyFn, len(c.collections.byKind)) i := 0 - for kind, collection := range c.legacyCacheCollections.byKind { - kind, collection := kind, collection - ii := i - i++ - - g.Go(func() (err error) { - ctx, span := c.Tracer.Start( - ctx, - fmt.Sprintf("cache/fetch/%s", kind.String()), - oteltrace.WithAttributes( - attribute.String("target", c.target), - ), - ) - defer func() { apitracing.EndSpan(span, err) }() - - _, cacheOK := confirmedKinds[resourceKind{kind: kind.kind, subkind: kind.subkind}] - applyfn, err := collection.fetch(ctx, cacheOK) - if err != nil { - return trace.Wrap(err, "failed to fetch resource: %q", kind) - } - - applyfns[ii] = tracedApplyFn(fetchSpan, c.Tracer, kind, applyfn) - return nil - }) - } for kind, handler := range c.collections.byKind { ii := i @@ -1634,14 +1535,9 @@ func (c *Cache) fetch(ctx context.Context, confirmedKinds map[resourceKind]types func (c *Cache) processEvent(ctx context.Context, event types.Event) error { resourceKind := resourceKindFromResource(event.Resource) - legacyCollection, legacyFound := c.legacyCacheCollections.byKind[resourceKind] handler, handlerFound := c.collections.byKind[resourceKind] switch { - case legacyFound: - if err := legacyCollection.processEvent(ctx, event); err != nil { - return trace.Wrap(err) - } case handlerFound: switch event.Type { case types.OpDelete: @@ -1670,123 +1566,6 @@ func (c *Cache) processEvent(ctx context.Context, event types.Event) error { return nil } -// GetSecurityAuditQuery returns the specified audit query resource. -func (c *Cache) GetSecurityAuditQuery(ctx context.Context, name string) (*secreports.AuditQuery, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityAuditQuery") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.auditQueries) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetSecurityAuditQuery(ctx, name) -} - -// GetSecurityAuditQueries returns a list of all audit query resources. -func (c *Cache) GetSecurityAuditQueries(ctx context.Context) ([]*secreports.AuditQuery, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityAuditQueries") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.auditQueries) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetSecurityAuditQueries(ctx) -} - -// ListSecurityAuditQueries returns a paginated list of all audit query resources. -func (c *Cache) ListSecurityAuditQueries(ctx context.Context, pageSize int, nextKey string) ([]*secreports.AuditQuery, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListSecurityAuditQueries") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.auditQueries) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListSecurityAuditQueries(ctx, pageSize, nextKey) -} - -// GetSecurityReport returns the specified security report resource. -func (c *Cache) GetSecurityReport(ctx context.Context, name string) (*secreports.Report, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityReport") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.secReports) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetSecurityReport(ctx, name) -} - -// GetSecurityReports returns a list of all security report resources. -func (c *Cache) GetSecurityReports(ctx context.Context) ([]*secreports.Report, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityReports") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.secReports) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetSecurityReports(ctx) -} - -// ListSecurityReports returns a paginated list of all security report resources. -func (c *Cache) ListSecurityReports(ctx context.Context, pageSize int, nextKey string) ([]*secreports.Report, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListSecurityReports") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.secReports) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListSecurityReports(ctx, pageSize, nextKey) -} - -// GetSecurityReportState returns the specified security report state resource. -func (c *Cache) GetSecurityReportState(ctx context.Context, name string) (*secreports.ReportState, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityReportState") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.secReportsStates) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetSecurityReportState(ctx, name) -} - -// GetSecurityReportsStates returns a list of all security report resources. -func (c *Cache) GetSecurityReportsStates(ctx context.Context) ([]*secreports.ReportState, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityReportsStates") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.secReportsStates) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetSecurityReportsStates(ctx) -} - -// ListSecurityReportsStates returns a paginated list of all security report resources. -func (c *Cache) ListSecurityReportsStates(ctx context.Context, pageSize int, nextKey string) ([]*secreports.ReportState, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListSecurityReportsStates") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.secReportsStates) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListSecurityReportsStates(ctx, pageSize, nextKey) -} - // ListResources is a part of auth.Cache implementation func (c *Cache) ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListResources") diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index b8ff6d57d4878..800df9d00fd0a 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -1248,87 +1248,6 @@ func newUserTasks(t *testing.T) *usertasksv1.UserTask { return ut } -// TestAuditQuery tests that CRUD operations on access list rule resources are -// replicated from the backend to the cache. -func TestAuditQuery(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[*secreports.AuditQuery]{ - newResource: func(name string) (*secreports.AuditQuery, error) { - return newAuditQuery(t, name), nil - }, - create: func(ctx context.Context, item *secreports.AuditQuery) error { - err := p.secReports.UpsertSecurityAuditQuery(ctx, item) - return trace.Wrap(err) - }, - list: p.secReports.GetSecurityAuditQueries, - cacheGet: p.cache.GetSecurityAuditQuery, - cacheList: p.cache.GetSecurityAuditQueries, - update: func(ctx context.Context, item *secreports.AuditQuery) error { - err := p.secReports.UpsertSecurityAuditQuery(ctx, item) - return trace.Wrap(err) - }, - deleteAll: p.secReports.DeleteAllSecurityAuditQueries, - }) -} - -// TestSecurityReportState tests that CRUD operations on security report state resources are -// replicated from the backend to the cache. -func TestSecurityReports(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[*secreports.Report]{ - newResource: func(name string) (*secreports.Report, error) { - return newSecurityReport(t, name), nil - }, - create: func(ctx context.Context, item *secreports.Report) error { - err := p.secReports.UpsertSecurityReport(ctx, item) - return trace.Wrap(err) - }, - list: p.secReports.GetSecurityReports, - cacheGet: p.cache.GetSecurityReport, - cacheList: p.cache.GetSecurityReports, - update: func(ctx context.Context, item *secreports.Report) error { - err := p.secReports.UpsertSecurityReport(ctx, item) - return trace.Wrap(err) - }, - deleteAll: p.secReports.DeleteAllSecurityReports, - }) -} - -// TestSecurityReportState tests that CRUD operations on security report state resources are -// replicated from the backend to the cache. -func TestSecurityReportState(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[*secreports.ReportState]{ - newResource: func(name string) (*secreports.ReportState, error) { - return newSecurityReportState(t, name), nil - }, - create: func(ctx context.Context, item *secreports.ReportState) error { - err := p.secReports.UpsertSecurityReportsState(ctx, item) - return trace.Wrap(err) - }, - list: p.secReports.GetSecurityReportsStates, - cacheGet: p.cache.GetSecurityReportState, - cacheList: p.cache.GetSecurityReportsStates, - update: func(ctx context.Context, item *secreports.ReportState) error { - err := p.secReports.UpsertSecurityReportsState(ctx, item) - return trace.Wrap(err) - }, - deleteAll: p.secReports.DeleteAllSecurityReportsStates, - }) -} - // testResources is a generic tester for resources. func testResources[T types.Resource](t *testing.T, p *testPack, funcs testFuncs[T]) { ctx := context.Background() diff --git a/lib/cache/collections.go b/lib/cache/collections.go index aa4dffc6a007a..fc37132af1e1c 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -18,6 +18,7 @@ package cache import ( "context" + "fmt" "github.com/gravitational/trace" @@ -38,6 +39,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" "github.com/gravitational/teleport/api/types/discoveryconfig" + "github.com/gravitational/teleport/api/types/secreports" "github.com/gravitational/teleport/api/types/userloginstate" ) @@ -132,6 +134,9 @@ type collections struct { discoveryConfigs *collection[*discoveryconfig.DiscoveryConfig, discoveryConfigIndex] provisioningStates *collection[*provisioningv1.PrincipalState, principalStateIndex] identityCenterPrincipalAssignments *collection[*identitycenterv1.PrincipalAssignment, identityCenterPrincipalAssignmentIndex] + auditQueries *collection[*secreports.AuditQuery, auditQueryIndex] + secReports *collection[*secreports.Report, securityReportIndex] + secReportsStates *collection[*secreports.ReportState, securityReportStateIndex] } // isKnownUncollectedKind is true if a resource kind is not stored in @@ -149,7 +154,7 @@ func isKnownUncollectedKind(kind string) bool { // setupCollections ensures that the appropriate [collection] is // initialized for all provided [types.WatcKind]s. An error is // returned if a [types.WatchKind] has no associated [collection]. -func setupCollections(c Config, legacyCollections map[resourceKind]legacyCollection) (*collections, error) { +func setupCollections(c Config) (*collections, error) { out := &collections{ byKind: make(map[resourceKind]collectionHandler, 1), } @@ -702,14 +707,78 @@ func setupCollections(c Config, legacyCollections map[resourceKind]legacyCollect out.identityCenterPrincipalAssignments = collect out.byKind[resourceKind] = out.identityCenterPrincipalAssignments + case types.KindAuditQuery: + collect, err := newAuditQueryCollection(c.SecReports, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.auditQueries = collect + out.byKind[resourceKind] = out.auditQueries + case types.KindSecurityReport: + collect, err := newSecurityReportCollection(c.SecReports, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.secReports = collect + out.byKind[resourceKind] = out.secReports + case types.KindSecurityReportState: + collect, err := newSecurityReportStateCollection(c.SecReports, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.secReportsStates = collect + out.byKind[resourceKind] = out.secReportsStates default: - _, legacyOk := legacyCollections[resourceKind] - if _, ok := out.byKind[resourceKind]; !ok && !legacyOk { + if _, ok := out.byKind[resourceKind]; !ok { return nil, trace.BadParameter("resource %q is not supported", watch.Kind) } } - } return out, nil } + +func resourceKindFromWatchKind(wk types.WatchKind) resourceKind { + switch wk.Kind { + case types.KindWebSession: + // Web sessions use subkind to differentiate between + // the types of sessions + return resourceKind{ + kind: wk.Kind, + subkind: wk.SubKind, + } + } + return resourceKind{ + kind: wk.Kind, + } +} + +func resourceKindFromResource(res types.Resource) resourceKind { + switch res.GetKind() { + case types.KindWebSession: + // Web sessions use subkind to differentiate between + // the types of sessions + return resourceKind{ + kind: res.GetKind(), + subkind: res.GetSubKind(), + } + } + return resourceKind{ + kind: res.GetKind(), + } +} + +type resourceKind struct { + kind string + subkind string +} + +func (r resourceKind) String() string { + if r.subkind == "" { + return r.kind + } + return fmt.Sprintf("%s/%s", r.kind, r.subkind) +} diff --git a/lib/cache/generic_legacy_collection.go b/lib/cache/generic_legacy_collection.go deleted file mode 100644 index 071a51ca33ec2..0000000000000 --- a/lib/cache/generic_legacy_collection.go +++ /dev/null @@ -1,124 +0,0 @@ -// Teleport -// Copyright (C) 2024 Gravitational, Inc. -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package cache - -import ( - "context" - - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/api/types" -) - -// genericCollection is a generic collection implementation for resource type T with collection-specific logic -// encapsulated in executor type E. Type R provides getter methods related to the collection, e.g. GetNodes(), -// GetRoles(). -type genericCollection[T any, R any, E executor[T, R]] struct { - cache *Cache - watch types.WatchKind - exec E -} - -// fetch implements collection -func (g *genericCollection[T, R, _]) fetch(ctx context.Context, cacheOK bool) (apply func(ctx context.Context) error, err error) { - // Singleton objects will only get deleted or updated, not both - deleteSingleton := false - - var resources []T - if cacheOK { - resources, err = g.exec.getAll(ctx, g.cache, g.watch.LoadSecrets) - if err != nil { - if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - deleteSingleton = true - } - } - - return func(ctx context.Context) error { - // Always perform the delete if this is not a singleton, otherwise - // only perform the delete if the singleton wasn't found - // or the resource kind isn't cached in the current generation. - if !g.exec.isSingleton() || deleteSingleton || !cacheOK { - if err := g.exec.deleteAll(ctx, g.cache); err != nil { - if !trace.IsNotFound(err) { - return trace.Wrap(err) - } - } - } - // If this is a singleton and we performed a deletion, return here - // because we only want to update or delete a singleton, not both. - // Also don't continue if the resource kind isn't cached in the current generation. - if g.exec.isSingleton() && deleteSingleton || !cacheOK { - return nil - } - for _, resource := range resources { - if err := g.exec.upsert(ctx, g.cache, resource); err != nil { - return trace.Wrap(err) - } - } - return nil - }, nil -} - -// processEvent implements collection -func (g *genericCollection[T, R, _]) processEvent(ctx context.Context, event types.Event) error { - switch event.Type { - case types.OpDelete: - if err := g.exec.delete(ctx, g.cache, event.Resource); err != nil { - if !trace.IsNotFound(err) { - g.cache.Logger.WarnContext(ctx, "Failed to delete resource", "error", err) - return trace.Wrap(err) - } - } - case types.OpPut: - var resource T - var ok bool - switch r := event.Resource.(type) { - case interface{ UnwrapT() T }: - resource, ok = r.UnwrapT(), true - default: - resource, ok = event.Resource.(T) - } - - if !ok { - return trace.BadParameter("unexpected type %T (expected %T)", event.Resource, resource) - } - - if err := g.exec.upsert(ctx, g.cache, resource); err != nil { - return trace.Wrap(err) - } - default: - g.cache.Logger.WarnContext(ctx, "Skipping unsupported event type", "event", event.Type) - } - return nil -} - -// watchKind implements collection -func (g *genericCollection[T, R, _]) watchKind() types.WatchKind { - return g.watch -} - -var _ legacyCollection = (*genericCollection[types.Resource, any, executor[types.Resource, any]])(nil) - -// genericCollection obtains the reader object from the executor based on the provided health status of the cache. -// Note that cacheOK set to true means that cache is overall healthy and the collection was confirmed as supported. -func (c *genericCollection[T, R, _]) getReader(cacheOK bool) R { - return c.exec.getReader(c.cache, cacheOK) -} - -var _ collectionReader[any] = (*genericCollection[types.Resource, any, executor[types.Resource, any]])(nil) diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go deleted file mode 100644 index 429c7ad6db634..0000000000000 --- a/lib/cache/legacy_collections.go +++ /dev/null @@ -1,315 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -//nolint:unused // Because the executors generate a large amount of false positives. -package cache - -import ( - "context" - "fmt" - - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/api/client/proto" - usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/secreports" - "github.com/gravitational/teleport/lib/services" -) - -// legacyCollection is responsible for managing collection -// of resources updates -type legacyCollection interface { - // fetch fetches resources and returns a function which will apply said resources to the cache. - // fetch *must* not mutate cache state outside of the apply function. - // The provided cacheOK flag indicates whether this collection will be included in the cache generation that is - // being prepared. If cacheOK is false, fetch shouldn't fetch any resources, but the apply function that it - // returns must still delete resources from the backend. - fetch(ctx context.Context, cacheOK bool) (apply func(ctx context.Context) error, err error) - // processEvent processes event - processEvent(ctx context.Context, e types.Event) error - // watchKind returns a watch - // required for this collection - watchKind() types.WatchKind -} - -// executor[T, R] is a specific way to run the collector operations that we need -// for the genericCollector for a generic resource type T and its reader type R. -type executor[T any, R any] interface { - // getAll returns all of the target resources from the auth server. - // For singleton objects, this should be a size-1 slice. - getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]T, error) - - // upsert will create or update a target resource in the cache. - upsert(ctx context.Context, cache *Cache, value T) error - - // deleteAll will delete all target resources of the type in the cache. - deleteAll(ctx context.Context, cache *Cache) error - - // delete will delete a single target resource from the cache. For - // singletons, this is usually an alias to deleteAll. - delete(ctx context.Context, cache *Cache, resource types.Resource) error - - // isSingleton will return true if the target resource is a singleton. - isSingleton() bool - - // getReader returns the appropriate reader type R based on the health status of the cache. - // Reader type R provides getter methods related to the collection, e.g. GetNodes(), GetRoles(). - // Note that cacheOK set to true means that cache is overall healthy and the collection was confirmed as supported. - getReader(c *Cache, cacheOK bool) R -} - -// noReader is returned by getReader for resources which aren't directly used by the cache, and therefore have no associated reader. -type noReader struct{} - -type userTasksGetter interface { - ListUserTasks(ctx context.Context, pageSize int64, nextToken string, filters *usertasksv1.ListUserTasksFilters) ([]*usertasksv1.UserTask, string, error) - GetUserTask(ctx context.Context, name string) (*usertasksv1.UserTask, error) -} - -// legacyCollections is a registry of resource collections used by Cache. -type legacyCollections struct { - // byKind is a map of registered collections by resource Kind/SubKind - byKind map[resourceKind]legacyCollection - - auditQueries collectionReader[services.SecurityAuditQueryGetter] - secReports collectionReader[services.SecurityReportGetter] - secReportsStates collectionReader[services.SecurityReportStateGetter] -} - -// setupLegacyCollections returns a registry of legacyCollections. -func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollections, error) { - collections := &legacyCollections{ - byKind: make(map[resourceKind]legacyCollection, len(watches)), - } - for _, watch := range watches { - resourceKind := resourceKindFromWatchKind(watch) - switch watch.Kind { - case types.KindAuditQuery: - if c.SecReports == nil { - return nil, trace.BadParameter("missing parameter SecReports") - } - collections.auditQueries = &genericCollection[*secreports.AuditQuery, services.SecurityAuditQueryGetter, auditQueryExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.auditQueries - case types.KindSecurityReport: - if c.SecReports == nil { - return nil, trace.BadParameter("missing parameter KindSecurityReport") - } - collections.secReports = &genericCollection[*secreports.Report, services.SecurityReportGetter, secReportExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.secReports - case types.KindSecurityReportState: - if c.SecReports == nil { - return nil, trace.BadParameter("missing parameter KindSecurityReport") - } - collections.secReportsStates = &genericCollection[*secreports.ReportState, services.SecurityReportStateGetter, secReportStateExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.secReportsStates - } - } - return collections, nil -} - -func resourceKindFromWatchKind(wk types.WatchKind) resourceKind { - switch wk.Kind { - case types.KindWebSession: - // Web sessions use subkind to differentiate between - // the types of sessions - return resourceKind{ - kind: wk.Kind, - subkind: wk.SubKind, - } - } - return resourceKind{ - kind: wk.Kind, - } -} - -func resourceKindFromResource(res types.Resource) resourceKind { - switch res.GetKind() { - case types.KindWebSession: - // Web sessions use subkind to differentiate between - // the types of sessions - return resourceKind{ - kind: res.GetKind(), - subkind: res.GetSubKind(), - } - } - return resourceKind{ - kind: res.GetKind(), - } -} - -type resourceKind struct { - kind string - subkind string -} - -func (r resourceKind) String() string { - if r.subkind == "" { - return r.kind - } - return fmt.Sprintf("%s/%s", r.kind, r.subkind) -} - -// collectionReader extends the collection interface, adding routing capabilities. -type collectionReader[R any] interface { - legacyCollection - - // getReader returns the appropriate reader type T based on the health status of the cache. - // Reader type R provides getter methods related to the collection, e.g. GetNodes(), GetRoles(). - // Note that cacheOK set to true means that cache is overall healthy and the collection was confirmed as supported. - getReader(cacheOK bool) R -} - -type resourceGetter interface { - ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) -} - -type auditQueryExecutor struct{} - -func (auditQueryExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*secreports.AuditQuery, error) { - var out []*secreports.AuditQuery - var nextToken string - for { - var page []*secreports.AuditQuery - var err error - - page, nextToken, err = cache.secReportsCache.ListSecurityAuditQueries(ctx, 0 /* default page size */, nextToken) - if err != nil { - return nil, trace.Wrap(err) - } - out = append(out, page...) - if nextToken == "" { - break - } - } - return out, nil -} - -func (auditQueryExecutor) upsert(ctx context.Context, cache *Cache, resource *secreports.AuditQuery) error { - err := cache.secReportsCache.UpsertSecurityAuditQuery(ctx, resource) - return trace.Wrap(err) -} - -func (auditQueryExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return trace.Wrap(cache.secReportsCache.DeleteAllSecurityReports(ctx)) -} - -func (auditQueryExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return trace.Wrap(cache.secReportsCache.DeleteSecurityAuditQuery(ctx, resource.GetName())) -} - -func (auditQueryExecutor) isSingleton() bool { return false } - -func (auditQueryExecutor) getReader(cache *Cache, cacheOK bool) services.SecurityAuditQueryGetter { - if cacheOK { - return cache.secReportsCache - } - return cache.Config.SecReports -} - -var _ executor[*secreports.AuditQuery, services.SecurityAuditQueryGetter] = auditQueryExecutor{} - -type secReportExecutor struct{} - -func (secReportExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*secreports.Report, error) { - var out []*secreports.Report - var nextToken string - for { - var page []*secreports.Report - var err error - - page, nextToken, err = cache.secReportsCache.ListSecurityReports(ctx, 0 /* default page size */, nextToken) - if err != nil { - return nil, trace.Wrap(err) - } - out = append(out, page...) - if nextToken == "" { - break - } - } - return out, nil -} - -func (secReportExecutor) upsert(ctx context.Context, cache *Cache, resource *secreports.Report) error { - err := cache.secReportsCache.UpsertSecurityReport(ctx, resource) - return trace.Wrap(err) -} - -func (secReportExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return trace.Wrap(cache.secReportsCache.DeleteAllSecurityReports(ctx)) -} - -func (secReportExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return trace.Wrap(cache.secReportsCache.DeleteSecurityReport(ctx, resource.GetName())) -} - -func (secReportExecutor) isSingleton() bool { return false } - -func (secReportExecutor) getReader(cache *Cache, cacheOK bool) services.SecurityReportGetter { - if cacheOK { - return cache.secReportsCache - } - return cache.Config.SecReports -} - -var _ executor[*secreports.Report, services.SecurityReportGetter] = secReportExecutor{} - -type secReportStateExecutor struct{} - -func (secReportStateExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*secreports.ReportState, error) { - var out []*secreports.ReportState - var nextToken string - for { - var page []*secreports.ReportState - var err error - - page, nextToken, err = cache.secReportsCache.ListSecurityReportsStates(ctx, 0 /* default page size */, nextToken) - if err != nil { - return nil, trace.Wrap(err) - } - out = append(out, page...) - if nextToken == "" { - break - } - } - return out, nil -} - -func (secReportStateExecutor) upsert(ctx context.Context, cache *Cache, resource *secreports.ReportState) error { - err := cache.secReportsCache.UpsertSecurityReportsState(ctx, resource) - return trace.Wrap(err) -} - -func (secReportStateExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return trace.Wrap(cache.secReportsCache.DeleteAllSecurityReportsStates(ctx)) -} - -func (secReportStateExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return trace.Wrap(cache.secReportsCache.DeleteSecurityReportsState(ctx, resource.GetName())) -} - -func (secReportStateExecutor) isSingleton() bool { return false } - -func (secReportStateExecutor) getReader(cache *Cache, cacheOK bool) services.SecurityReportStateGetter { - if cacheOK { - return cache.secReportsCache - } - return cache.Config.SecReports -} - -var _ executor[*secreports.ReportState, services.SecurityReportStateGetter] = secReportStateExecutor{} diff --git a/lib/cache/security_report.go b/lib/cache/security_report.go new file mode 100644 index 0000000000000..abfcba742384e --- /dev/null +++ b/lib/cache/security_report.go @@ -0,0 +1,339 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/header" + "github.com/gravitational/teleport/api/types/secreports" + "github.com/gravitational/teleport/lib/services" +) + +type auditQueryIndex string + +const auditQueryNameIndex auditQueryIndex = "name" + +func newAuditQueryCollection(upstream services.SecReports, w types.WatchKind) (*collection[*secreports.AuditQuery, auditQueryIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter SecReports") + } + + return &collection[*secreports.AuditQuery, auditQueryIndex]{ + store: newStore( + (*secreports.AuditQuery).Clone, + map[auditQueryIndex]func(*secreports.AuditQuery) string{ + auditQueryNameIndex: func(r *secreports.AuditQuery) string { + return r.GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*secreports.AuditQuery, error) { + var out []*secreports.AuditQuery + var nextToken string + for { + var page []*secreports.AuditQuery + var err error + + page, nextToken, err = upstream.ListSecurityAuditQueries(ctx, 0 /* default page size */, nextToken) + if err != nil { + // AccessDenied is returned if the cluster is not licensed for access monitoring. + if trace.IsAccessDenied(err) { + return nil, nil + } + return nil, trace.Wrap(err) + } + out = append(out, page...) + if nextToken == "" { + break + } + } + return out, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *secreports.AuditQuery { + return &secreports.AuditQuery{ + ResourceHeader: header.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: header.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + } + }, + watch: w, + }, nil +} + +// GetSecurityAuditQuery returns the specified audit query resource. +func (c *Cache) GetSecurityAuditQuery(ctx context.Context, name string) (*secreports.AuditQuery, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityAuditQuery") + defer span.End() + + getter := genericGetter[*secreports.AuditQuery, auditQueryIndex]{ + cache: c, + collection: c.collections.auditQueries, + index: auditQueryNameIndex, + upstreamGet: c.Config.SecReports.GetSecurityAuditQuery, + } + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) +} + +// GetSecurityAuditQueries returns a list of all audit query resources. +func (c *Cache) GetSecurityAuditQueries(ctx context.Context) ([]*secreports.AuditQuery, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityAuditQueries") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.auditQueries) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + out, err := c.Config.SecReports.GetSecurityAuditQueries(ctx) + return out, trace.Wrap(err) + } + + out := make([]*secreports.AuditQuery, 0, rg.store.len()) + for a := range rg.store.resources(auditQueryNameIndex, "", "") { + out = append(out, a.Clone()) + } + + return out, nil +} + +// ListSecurityAuditQueries returns a paginated list of all audit query resources. +func (c *Cache) ListSecurityAuditQueries(ctx context.Context, pageSize int, pageToken string) ([]*secreports.AuditQuery, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListSecurityAuditQueries") + defer span.End() + + lister := genericLister[*secreports.AuditQuery, auditQueryIndex]{ + cache: c, + collection: c.collections.auditQueries, + index: auditQueryNameIndex, + upstreamList: c.Config.SecReports.ListSecurityAuditQueries, + nextToken: func(a *secreports.AuditQuery) string { + return a.GetMetadata().Name + }, + } + out, next, err := lister.list(ctx, pageSize, pageToken) + return out, next, trace.Wrap(err) +} + +type securityReportIndex string + +const securityReportNameIndex securityReportIndex = "name" + +func newSecurityReportCollection(upstream services.SecReports, w types.WatchKind) (*collection[*secreports.Report, securityReportIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter SecReports") + } + + return &collection[*secreports.Report, securityReportIndex]{ + store: newStore( + (*secreports.Report).Clone, + map[securityReportIndex]func(*secreports.Report) string{ + securityReportNameIndex: func(r *secreports.Report) string { + return r.GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*secreports.Report, error) { + var out []*secreports.Report + var nextToken string + for { + var page []*secreports.Report + var err error + + page, nextToken, err = upstream.ListSecurityReports(ctx, 0 /* default page size */, nextToken) + if err != nil { + // AccessDenied is returned if the cluster is not licensed for access monitoring. + if trace.IsAccessDenied(err) { + return nil, nil + } + + return nil, trace.Wrap(err) + } + out = append(out, page...) + if nextToken == "" { + break + } + } + return out, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *secreports.Report { + return &secreports.Report{ + ResourceHeader: header.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: header.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + } + }, + watch: w, + }, nil +} + +// GetSecurityReport returns the specified security report resource. +func (c *Cache) GetSecurityReport(ctx context.Context, name string) (*secreports.Report, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityReport") + defer span.End() + + getter := genericGetter[*secreports.Report, securityReportIndex]{ + cache: c, + collection: c.collections.secReports, + index: securityReportNameIndex, + upstreamGet: c.Config.SecReports.GetSecurityReport, + } + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) +} + +// GetSecurityReports returns a list of all security report resources. +func (c *Cache) GetSecurityReports(ctx context.Context) ([]*secreports.Report, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityReports") + defer span.End() + + rg, err := acquireReadGuard(c, c.collections.secReports) + if err != nil { + return nil, trace.Wrap(err) + } + defer rg.Release() + + if !rg.ReadCache() { + out, err := c.Config.SecReports.GetSecurityReports(ctx) + return out, trace.Wrap(err) + } + + out := make([]*secreports.Report, 0, rg.store.len()) + for r := range rg.store.resources(securityReportNameIndex, "", "") { + out = append(out, r.Clone()) + } + + return out, nil +} + +// ListSecurityReports returns a paginated list of all security report resources. +func (c *Cache) ListSecurityReports(ctx context.Context, pageSize int, pageToken string) ([]*secreports.Report, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListSecurityReports") + defer span.End() + + lister := genericLister[*secreports.Report, securityReportIndex]{ + cache: c, + collection: c.collections.secReports, + index: securityReportNameIndex, + upstreamList: c.Config.SecReports.ListSecurityReports, + nextToken: func(r *secreports.Report) string { + return r.GetMetadata().Name + }, + } + out, next, err := lister.list(ctx, pageSize, pageToken) + return out, next, trace.Wrap(err) +} + +type securityReportStateIndex string + +const securityReportStateNameIndex securityReportStateIndex = "name" + +func newSecurityReportStateCollection(upstream services.SecReports, w types.WatchKind) (*collection[*secreports.ReportState, securityReportStateIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter SecReports") + } + + return &collection[*secreports.ReportState, securityReportStateIndex]{ + store: newStore( + (*secreports.ReportState).Clone, + map[securityReportStateIndex]func(*secreports.ReportState) string{ + securityReportStateNameIndex: func(r *secreports.ReportState) string { + return r.GetName() + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*secreports.ReportState, error) { + var out []*secreports.ReportState + var nextToken string + for { + var page []*secreports.ReportState + var err error + + page, nextToken, err = upstream.ListSecurityReportsStates(ctx, 0 /* default page size */, nextToken) + if err != nil { + // AccessDenied is returned if the cluster is not licensed for access monitoring. + if trace.IsAccessDenied(err) { + return nil, nil + } + + return nil, trace.Wrap(err) + } + out = append(out, page...) + if nextToken == "" { + break + } + } + return out, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *secreports.ReportState { + return &secreports.ReportState{ + ResourceHeader: header.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: header.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + } + }, + watch: w, + }, nil +} + +// GetSecurityReportState returns the specified security report state resource. +func (c *Cache) GetSecurityReportState(ctx context.Context, name string) (*secreports.ReportState, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetSecurityReportState") + defer span.End() + + getter := genericGetter[*secreports.ReportState, securityReportStateIndex]{ + cache: c, + collection: c.collections.secReportsStates, + index: securityReportStateNameIndex, + upstreamGet: c.Config.SecReports.GetSecurityReportState, + } + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) +} + +// ListSecurityReportsStates returns a paginated list of all security report resources. +func (c *Cache) ListSecurityReportsStates(ctx context.Context, pageSize int, pageToken string) ([]*secreports.ReportState, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListSecurityReportsStates") + defer span.End() + + lister := genericLister[*secreports.ReportState, securityReportStateIndex]{ + cache: c, + collection: c.collections.secReportsStates, + index: securityReportStateNameIndex, + upstreamList: c.Config.SecReports.ListSecurityReportsStates, + nextToken: func(r *secreports.ReportState) string { + return r.GetMetadata().Name + }, + } + out, next, err := lister.list(ctx, pageSize, pageToken) + return out, next, trace.Wrap(err) +} diff --git a/lib/cache/sercurity_report_test.go b/lib/cache/sercurity_report_test.go new file mode 100644 index 0000000000000..d7d78f15f42bd --- /dev/null +++ b/lib/cache/sercurity_report_test.go @@ -0,0 +1,226 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types/secreports" +) + +// TestAuditQuery tests that CRUD operations on access list rule resources are +// replicated from the backend to the cache. +func TestAuditQuery(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + t.Run("GetSecurityAuditQueries", func(t *testing.T) { + testResources(t, p, testFuncs[*secreports.AuditQuery]{ + newResource: func(name string) (*secreports.AuditQuery, error) { + return newAuditQuery(t, name), nil + }, + create: func(ctx context.Context, item *secreports.AuditQuery) error { + err := p.secReports.UpsertSecurityAuditQuery(ctx, item) + return trace.Wrap(err) + }, + list: p.secReports.GetSecurityAuditQueries, + cacheGet: p.cache.GetSecurityAuditQuery, + cacheList: p.cache.GetSecurityAuditQueries, + update: func(ctx context.Context, item *secreports.AuditQuery) error { + err := p.secReports.UpsertSecurityAuditQuery(ctx, item) + return trace.Wrap(err) + }, + deleteAll: p.secReports.DeleteAllSecurityAuditQueries, + }) + }) + + t.Run("ListSecurityAuditQueries", func(t *testing.T) { + testResources(t, p, testFuncs[*secreports.AuditQuery]{ + newResource: func(name string) (*secreports.AuditQuery, error) { + return newAuditQuery(t, name), nil + }, + create: func(ctx context.Context, item *secreports.AuditQuery) error { + err := p.secReports.UpsertSecurityAuditQuery(ctx, item) + return trace.Wrap(err) + }, + list: p.secReports.GetSecurityAuditQueries, + cacheGet: p.cache.GetSecurityAuditQuery, + cacheList: func(ctx context.Context) ([]*secreports.AuditQuery, error) { + var out []*secreports.AuditQuery + var startKey string + + for { + resp, next, err := p.cache.ListSecurityAuditQueries(ctx, 0, startKey) + if err != nil { + return nil, trace.Wrap(err) + } + + out = append(out, resp...) + startKey = next + if next == "" { + break + } + } + + return out, nil + }, + update: func(ctx context.Context, item *secreports.AuditQuery) error { + err := p.secReports.UpsertSecurityAuditQuery(ctx, item) + return trace.Wrap(err) + }, + deleteAll: p.secReports.DeleteAllSecurityAuditQueries, + }) + }) + +} + +// TestSecurityReportState tests that CRUD operations on security report state resources are +// replicated from the backend to the cache. +func TestSecurityReports(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + t.Run("GetSecurityReports", func(t *testing.T) { + testResources(t, p, testFuncs[*secreports.Report]{ + newResource: func(name string) (*secreports.Report, error) { + return newSecurityReport(t, name), nil + }, + create: func(ctx context.Context, item *secreports.Report) error { + err := p.secReports.UpsertSecurityReport(ctx, item) + return trace.Wrap(err) + }, + list: p.secReports.GetSecurityReports, + cacheGet: p.cache.GetSecurityReport, + cacheList: p.cache.GetSecurityReports, + update: func(ctx context.Context, item *secreports.Report) error { + err := p.secReports.UpsertSecurityReport(ctx, item) + return trace.Wrap(err) + }, + deleteAll: p.secReports.DeleteAllSecurityReports, + }) + }) + t.Run("ListSecurityReports", func(t *testing.T) { + testResources(t, p, testFuncs[*secreports.Report]{ + newResource: func(name string) (*secreports.Report, error) { + return newSecurityReport(t, name), nil + }, + create: func(ctx context.Context, item *secreports.Report) error { + err := p.secReports.UpsertSecurityReport(ctx, item) + return trace.Wrap(err) + }, + list: p.secReports.GetSecurityReports, + cacheGet: p.cache.GetSecurityReport, + cacheList: func(ctx context.Context) ([]*secreports.Report, error) { + var out []*secreports.Report + var startKey string + + for { + resp, next, err := p.cache.ListSecurityReports(ctx, 0, startKey) + if err != nil { + return nil, trace.Wrap(err) + } + + out = append(out, resp...) + startKey = next + if next == "" { + break + } + } + + return out, nil + + }, + update: func(ctx context.Context, item *secreports.Report) error { + err := p.secReports.UpsertSecurityReport(ctx, item) + return trace.Wrap(err) + }, + deleteAll: p.secReports.DeleteAllSecurityReports, + }) + }) + +} + +// TestSecurityReportState tests that CRUD operations on security report state resources are +// replicated from the backend to the cache. +func TestSecurityReportState(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[*secreports.ReportState]{ + newResource: func(name string) (*secreports.ReportState, error) { + return newSecurityReportState(t, name), nil + }, + create: func(ctx context.Context, item *secreports.ReportState) error { + err := p.secReports.UpsertSecurityReportsState(ctx, item) + return trace.Wrap(err) + }, + list: func(ctx context.Context) ([]*secreports.ReportState, error) { + var out []*secreports.ReportState + var startKey string + for { + resp, next, err := p.secReports.ListSecurityReportsStates(ctx, 0, startKey) + if err != nil { + return nil, trace.Wrap(err) + } + + out = append(out, resp...) + + if next == "" { + break + } + startKey = next + } + + return out, nil + }, + cacheGet: p.cache.GetSecurityReportState, + cacheList: func(ctx context.Context) ([]*secreports.ReportState, error) { + var out []*secreports.ReportState + var startKey string + for { + resp, next, err := p.cache.ListSecurityReportsStates(ctx, 0, startKey) + if err != nil { + return nil, trace.Wrap(err) + } + + out = append(out, resp...) + + if next == "" { + break + } + startKey = next + } + + return out, nil + }, + update: func(ctx context.Context, item *secreports.ReportState) error { + err := p.secReports.UpsertSecurityReportsState(ctx, item) + return trace.Wrap(err) + }, + deleteAll: p.secReports.DeleteAllSecurityReportsStates, + }) + +} diff --git a/lib/services/local/secreports.go b/lib/services/local/secreports.go index 36307ce17d528..bd13f2868211c 100644 --- a/lib/services/local/secreports.go +++ b/lib/services/local/secreports.go @@ -122,12 +122,6 @@ func (s *SecReportsService) GetSecurityReports(ctx context.Context) ([]*secrepor return reports, trace.Wrap(err) } -// GetSecurityReportsStates returns security report states. -func (s *SecReportsService) GetSecurityReportsStates(ctx context.Context) ([]*secreports.ReportState, error) { - states, err := s.securityReportStateSvc.GetResources(ctx) - return states, trace.Wrap(err) -} - // GetSecurityAuditQuery returns audit query by name. func (s *SecReportsService) GetSecurityAuditQuery(ctx context.Context, name string) (*secreports.AuditQuery, error) { r, err := s.auditQuerySvc.GetResource(ctx, name) diff --git a/lib/services/secreports.go b/lib/services/secreports.go index 2689c05f6016d..935d6c3d6d8b6 100644 --- a/lib/services/secreports.go +++ b/lib/services/secreports.go @@ -51,8 +51,6 @@ type SecurityReportGetter interface { type SecurityReportStateGetter interface { // GetSecurityReportState returns a security report state. GetSecurityReportState(ctx context.Context, name string) (*secreports.ReportState, error) - // GetSecurityReportsStates returns security report states. - GetSecurityReportsStates(context.Context) ([]*secreports.ReportState, error) // ListSecurityReportsStates lists security report states. ListSecurityReportsStates(context.Context, int, string) ([]*secreports.ReportState, string, error) } @@ -64,24 +62,16 @@ type SecReports interface { UpsertSecurityAuditQuery(ctx context.Context, in *secreports.AuditQuery) error // DeleteSecurityAuditQuery deletes an audit query. DeleteSecurityAuditQuery(ctx context.Context, name string) error - // DeleteAllSecurityAuditQueries deletes all audit queries. - DeleteAllSecurityAuditQueries(context.Context) error SecurityReportGetter // UpsertSecurityReport upserts a security report. UpsertSecurityReport(ctx context.Context, item *secreports.Report) error // DeleteSecurityReport deletes a security report. DeleteSecurityReport(ctx context.Context, name string) error - // DeleteAllSecurityReports deletes all audit queries. - DeleteAllSecurityReports(context.Context) error SecurityReportStateGetter // UpsertSecurityReportsState upserts a security report state. UpsertSecurityReportsState(ctx context.Context, item *secreports.ReportState) error - // DeleteSecurityReportsState deletes all audit queries. - DeleteSecurityReportsState(ctx context.Context, name string) error - // DeleteAllSecurityReportsStates deletes all audit queries. - DeleteAllSecurityReportsStates(context.Context) error } // CostLimiter is the interface for the security cost limiter.