package endpoints import ( "context" "fmt" "time" "github.com/andres-erbsen/clock" "github.com/sirupsen/logrus" "github.com/spiffe/spire/pkg/common/telemetry" server_telemetry "github.com/spiffe/spire/pkg/common/telemetry/server" "github.com/spiffe/spire/pkg/server/api" "github.com/spiffe/spire/pkg/server/authorizedentries" "github.com/spiffe/spire/pkg/server/datastore" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type registrationEntries struct { cache *authorizedentries.Cache clk clock.Clock ds datastore.DataStore log logrus.FieldLogger metrics telemetry.Metrics eventsBeforeFirst map[uint]struct{} firstEvent uint firstEventTime time.Time lastEvent uint eventTracker *eventTracker sqlTransactionTimeout time.Duration fetchEntries map[string]struct{} // metrics change detection skippedEntryEvents int lastCacheStats authorizedentries.CacheStats } func (a *registrationEntries) captureChangedEntries(ctx context.Context) error { // first, reset the entries we might fetch. a.fetchEntries = make(map[string]struct{}) if err := a.searchBeforeFirstEvent(ctx); err != nil { return err } a.selectPolledEvents(ctx) if err := a.scanForNewEvents(ctx); err != nil { return err } return nil } func (a *registrationEntries) searchBeforeFirstEvent(ctx context.Context) error { // First event detected, and startup was less than a transaction timout away. if !a.firstEventTime.IsZero() && a.clk.Now().Sub(a.firstEventTime) <= a.sqlTransactionTimeout { resp, err := a.ds.ListRegistrationEntryEvents(ctx, &datastore.ListRegistrationEntryEventsRequest{ LessThanEventID: a.firstEvent, }) if err != nil { return err } for _, event := range resp.Events { // if we have seen it before, don't reload it. if _, seen := a.eventsBeforeFirst[event.EventID]; !seen { a.fetchEntries[event.EntryID] = struct{}{} a.eventsBeforeFirst[event.EventID] = struct{}{} } } return nil } // zero out unused event tracker if len(a.eventsBeforeFirst) != 0 { a.eventsBeforeFirst = make(map[uint]struct{}) } return nil } func (a *registrationEntries) selectPolledEvents(ctx context.Context) { // check if the polled events have appeared out-of-order selectedEvents := a.eventTracker.SelectEvents() for _, eventID := range selectedEvents { log := a.log.WithField(telemetry.EventID, eventID) event, err := a.ds.FetchRegistrationEntryEvent(ctx, eventID) switch status.Code(err) { case codes.OK: case codes.NotFound: continue default: log.WithError(err).Errorf("Failed to fetch info about skipped event %d", eventID) continue } a.fetchEntries[event.EntryID] = struct{}{} a.eventTracker.StopTracking(eventID) } a.eventTracker.FreeEvents(selectedEvents) } func (a *registrationEntries) scanForNewEvents(ctx context.Context) error { // If we haven't seen an event, scan for all events; otherwise, scan from the last event. var resp *datastore.ListRegistrationEntryEventsResponse var err error if a.firstEventTime.IsZero() { resp, err = a.ds.ListRegistrationEntryEvents(ctx, &datastore.ListRegistrationEntryEventsRequest{}) } else { resp, err = a.ds.ListRegistrationEntryEvents(ctx, &datastore.ListRegistrationEntryEventsRequest{ GreaterThanEventID: a.lastEvent, }) } if err != nil { return err } for _, event := range resp.Events { // event time determines if we have seen the first event. if a.firstEventTime.IsZero() { a.firstEvent = event.EventID a.lastEvent = event.EventID a.fetchEntries[event.EntryID] = struct{}{} a.firstEventTime = a.clk.Now() continue } // track any skipped event ids, should they appear later. for skipped := a.lastEvent + 1; skipped < event.EventID; skipped++ { a.eventTracker.StartTracking(skipped) } // every event adds its entry to the entry fetch list. a.fetchEntries[event.EntryID] = struct{}{} a.lastEvent = event.EventID } return nil } func (a *registrationEntries) loadCache(ctx context.Context, pageSize int32) error { // Build the cache var token string for { resp, err := a.ds.ListRegistrationEntries(ctx, &datastore.ListRegistrationEntriesRequest{ DataConsistency: datastore.RequireCurrent, // preliminary loading should not be done via read-replicas Pagination: &datastore.Pagination{ Token: token, PageSize: pageSize, }, }) if err != nil { return fmt.Errorf("failed to list registration entries: %w", err) } token = resp.Pagination.Token if token == "" { break } entries, err := api.RegistrationEntriesToProto(resp.Entries) if err != nil { return fmt.Errorf("failed to convert registration entries: %w", err) } for _, entry := range entries { a.cache.UpdateEntry(entry) } } return nil } // buildRegistrationEntriesCache Fetches all registration entries and adds them to the cache func buildRegistrationEntriesCache(ctx context.Context, log logrus.FieldLogger, metrics telemetry.Metrics, ds datastore.DataStore, clk clock.Clock, cache *authorizedentries.Cache, pageSize int32, cacheReloadInterval, sqlTransactionTimeout time.Duration) (*registrationEntries, error) { pollPeriods := PollPeriods(cacheReloadInterval, sqlTransactionTimeout) registrationEntries := ®istrationEntries{ cache: cache, clk: clk, ds: ds, log: log, metrics: metrics, sqlTransactionTimeout: sqlTransactionTimeout, eventsBeforeFirst: make(map[uint]struct{}), fetchEntries: make(map[string]struct{}), eventTracker: NewEventTracker(pollPeriods), skippedEntryEvents: -1, lastCacheStats: authorizedentries.CacheStats{ AliasesByEntryID: -1, AliasesBySelector: -1, EntriesByEntryID: -1, EntriesByParentID: -1, }, } if err := registrationEntries.loadCache(ctx, pageSize); err != nil { return nil, err } if err := registrationEntries.updateCache(ctx); err != nil { return nil, err } return registrationEntries, nil } // updateCache Fetches all the events since the last time this function was running and updates // the cache with all the changes. func (a *registrationEntries) updateCache(ctx context.Context) error { if err := a.captureChangedEntries(ctx); err != nil { return err } if err := a.updateCachedEntries(ctx); err != nil { return err } a.emitMetrics() return nil } // updateCacheEntry update/deletes/creates an individual registration entry in the cache. func (a *registrationEntries) updateCachedEntries(ctx context.Context) error { for entryId := range a.fetchEntries { commonEntry, err := a.ds.FetchRegistrationEntry(ctx, entryId) if err != nil { return err } if commonEntry == nil { a.cache.RemoveEntry(entryId) delete(a.fetchEntries, entryId) continue } entry, err := api.RegistrationEntryToProto(commonEntry) if err != nil { a.cache.RemoveEntry(entryId) delete(a.fetchEntries, entryId) a.log.WithField(telemetry.RegistrationID, entryId).Warn("Removed malformed registration entry from cache") continue } a.cache.UpdateEntry(entry) delete(a.fetchEntries, entryId) } return nil } func (a *registrationEntries) emitMetrics() { if a.skippedEntryEvents != int(a.eventTracker.EventCount()) { a.skippedEntryEvents = int(a.eventTracker.EventCount()) server_telemetry.SetSkippedEntryEventIDsCacheCountGauge(a.metrics, a.skippedEntryEvents) } cacheStats := a.cache.Stats() if a.lastCacheStats.AliasesByEntryID != cacheStats.AliasesByEntryID { a.lastCacheStats.AliasesByEntryID = cacheStats.AliasesByEntryID server_telemetry.SetNodeAliasesByEntryIDCacheCountGauge(a.metrics, a.lastCacheStats.AliasesByEntryID) } if a.lastCacheStats.AliasesBySelector != cacheStats.AliasesBySelector { a.lastCacheStats.AliasesBySelector = cacheStats.AliasesBySelector server_telemetry.SetNodeAliasesBySelectorCacheCountGauge(a.metrics, a.lastCacheStats.AliasesBySelector) } if a.lastCacheStats.EntriesByEntryID != cacheStats.EntriesByEntryID { a.lastCacheStats.EntriesByEntryID = cacheStats.EntriesByEntryID server_telemetry.SetEntriesByEntryIDCacheCountGauge(a.metrics, a.lastCacheStats.EntriesByEntryID) } if a.lastCacheStats.EntriesByParentID != cacheStats.EntriesByParentID { a.lastCacheStats.EntriesByParentID = cacheStats.EntriesByParentID server_telemetry.SetEntriesByParentIDCacheCountGauge(a.metrics, a.lastCacheStats.EntriesByParentID) } }