diff --git a/api/types/integration.go b/api/types/integration.go index 41131ad8c7b17..801abbd2b2f31 100644 --- a/api/types/integration.go +++ b/api/types/integration.go @@ -86,6 +86,8 @@ type Integration interface { GetCredentials() PluginCredentials // WithoutCredentials returns a copy without credentials. WithoutCredentials() Integration + // Clone returns a copy of the integration. + Clone() Integration } var _ ResourceWithLabels = (*IntegrationV1)(nil) @@ -605,6 +607,11 @@ func (ig *IntegrationV1) GetCredentials() PluginCredentials { return ig.Spec.Credentials } +// Clone returns a copy of the integration. +func (ig *IntegrationV1) Clone() Integration { + return utils.CloneProtoMsg(ig) +} + // WithoutCredentials returns a copy without credentials. func (ig *IntegrationV1) WithoutCredentials() Integration { if ig == nil || ig.GetCredentials() == nil { diff --git a/api/types/plugin_static_credentials.go b/api/types/plugin_static_credentials.go index e70df454c138f..ea03c4d69c9d1 100644 --- a/api/types/plugin_static_credentials.go +++ b/api/types/plugin_static_credentials.go @@ -16,7 +16,11 @@ limitations under the License. package types -import "github.com/gravitational/trace" +import ( + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/utils" +) // PluginStaticCredentials are static credentials for plugins. type PluginStaticCredentials interface { @@ -40,6 +44,8 @@ type PluginStaticCredentials interface { // GetSSHCertAuthorities will return the attached SSH CA keys. GetSSHCertAuthorities() []*SSHKeyPair + // Clone returns a copy of the credentials. + Clone() PluginStaticCredentials } // NewPluginStaticCredentials creates a new PluginStaticCredentialsV1 resource. @@ -58,6 +64,11 @@ func NewPluginStaticCredentials(metadata Metadata, spec PluginStaticCredentialsS return p, nil } +// Clone returns a copy of the credentials. +func (p *PluginStaticCredentialsV1) Clone() PluginStaticCredentials { + return utils.CloneProtoMsg(p) +} + // CheckAndSetDefaults checks validity of all parameters and sets defaults. func (p *PluginStaticCredentialsV1) CheckAndSetDefaults() error { p.setStaticFields() diff --git a/lib/cache/access_monitoring_rule.go b/lib/cache/access_monitoring_rule.go new file mode 100644 index 0000000000000..7bc6a518efdd6 --- /dev/null +++ b/lib/cache/access_monitoring_rule.go @@ -0,0 +1,132 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/services" +) + +type accessMonitoringRuleIndex string + +const accessMonitoringRuleNameIndex accessMonitoringRuleIndex = "name" + +func newAccessMonitoringRuleCollection(upstream services.AccessMonitoringRules, w types.WatchKind) (*collection[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter Integrations") + } + + return &collection[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex]{ + store: newStore(map[accessMonitoringRuleIndex]func(*accessmonitoringrulesv1.AccessMonitoringRule) string{ + accessMonitoringRuleNameIndex: func(r *accessmonitoringrulesv1.AccessMonitoringRule) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]*accessmonitoringrulesv1.AccessMonitoringRule, error) { + var resources []*accessmonitoringrulesv1.AccessMonitoringRule + var nextToken string + for { + var page []*accessmonitoringrulesv1.AccessMonitoringRule + var err error + page, nextToken, err = upstream.ListAccessMonitoringRules(ctx, 0 /* page size */, nextToken) + if err != nil { + return nil, trace.Wrap(err) + } + resources = append(resources, page...) + + if nextToken == "" { + break + } + } + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) *accessmonitoringrulesv1.AccessMonitoringRule { + return &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: &headerv1.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +// ListAccessMonitoringRules returns a paginated list of access monitoring rules. +func (c *Cache) ListAccessMonitoringRules(ctx context.Context, pageSize int, pageToken string) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListAccessMonitoringRules") + defer span.End() + + lister := genericLister[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex]{ + cache: c, + collection: c.collections.accessMonitoringRules, + index: accessMonitoringRuleNameIndex, + upstreamList: c.Config.AccessMonitoringRules.ListAccessMonitoringRules, + nextToken: func(t *accessmonitoringrulesv1.AccessMonitoringRule) string { + return t.GetMetadata().Name + }, + clone: utils.CloneProtoMsg[*accessmonitoringrulesv1.AccessMonitoringRule], + } + out, next, err := lister.list(ctx, pageSize, pageToken) + return out, next, trace.Wrap(err) +} + +// ListAccessMonitoringRulesWithFilter returns a paginated list of access monitoring rules. +func (c *Cache) ListAccessMonitoringRulesWithFilter(ctx context.Context, req *accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListAccessMonitoringRules") + defer span.End() + + lister := genericLister[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex]{ + cache: c, + collection: c.collections.accessMonitoringRules, + index: accessMonitoringRuleNameIndex, + upstreamList: c.Config.AccessMonitoringRules.ListAccessMonitoringRules, + nextToken: func(t *accessmonitoringrulesv1.AccessMonitoringRule) string { + return t.GetMetadata().Name + }, + clone: utils.CloneProtoMsg[*accessmonitoringrulesv1.AccessMonitoringRule], + filter: func(rule *accessmonitoringrulesv1.AccessMonitoringRule) bool { + return services.MatchAccessMonitoringRule(rule, req.GetSubjects(), req.GetNotificationName(), req.GetAutomaticReviewName()) + }, + } + out, next, err := lister.list(ctx, int(req.GetPageSize()), req.GetPageToken()) + return out, next, trace.Wrap(err) +} + +// GetAccessMonitoringRule returns the specified AccessMonitoringRule resources. +func (c *Cache) GetAccessMonitoringRule(ctx context.Context, name string) (*accessmonitoringrulesv1.AccessMonitoringRule, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetAccessMonitoringRule") + defer span.End() + + getter := genericGetter[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex]{ + cache: c, + collection: c.collections.accessMonitoringRules, + index: accessMonitoringRuleNameIndex, + upstreamGet: c.Config.AccessMonitoringRules.GetAccessMonitoringRule, + clone: utils.CloneProtoMsg[*accessmonitoringrulesv1.AccessMonitoringRule], + } + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) +} diff --git a/lib/cache/access_monitoring_rule_test.go b/lib/cache/access_monitoring_rule_test.go new file mode 100644 index 0000000000000..929e5c0aa73c0 --- /dev/null +++ b/lib/cache/access_monitoring_rule_test.go @@ -0,0 +1,226 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + "github.com/gravitational/teleport/api/types" +) + +func TestAccessMonitoringRules(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources153(t, p, testFuncs153[*accessmonitoringrulesv1.AccessMonitoringRule]{ + newResource: func(name string) (*accessmonitoringrulesv1.AccessMonitoringRule, error) { + return newAccessMonitoringRule(t), nil + }, + create: func(ctx context.Context, i *accessmonitoringrulesv1.AccessMonitoringRule) error { + _, err := p.accessMonitoringRules.CreateAccessMonitoringRule(ctx, i) + return err + }, + list: func(ctx context.Context) ([]*accessmonitoringrulesv1.AccessMonitoringRule, error) { + results, _, err := p.accessMonitoringRules.ListAccessMonitoringRules(ctx, 0, "") + return results, err + }, + cacheGet: p.cache.GetAccessMonitoringRule, + cacheList: func(ctx context.Context) ([]*accessmonitoringrulesv1.AccessMonitoringRule, error) { + results, _, err := p.cache.ListAccessMonitoringRules(ctx, 0, "") + return results, err + }, + update: func(ctx context.Context, i *accessmonitoringrulesv1.AccessMonitoringRule) error { + _, err := p.accessMonitoringRules.UpdateAccessMonitoringRule(ctx, i) + return err + }, + deleteAll: p.accessMonitoringRules.DeleteAllAccessMonitoringRules, + }) +} + +func TestListAccessMonitoringRulesWithFilter(t *testing.T) { + t.Parallel() + + tests := []struct { + description string + rule *accessmonitoringrulesv1.AccessMonitoringRule + req *accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest + expectedRule bool + }{ + { + description: "filter by notification integration", + rule: &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example-notification-rule", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "someCondition", + Notification: &accessmonitoringrulesv1.Notification{ + Name: "notificationIntegration", + }, + }, + }, + req: &accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest{ + Subjects: []string{types.KindAccessRequest}, + NotificationName: "notificationIntegration", + }, + expectedRule: true, + }, + { + description: "filter by automatic_review integration", + rule: &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example-automatic-approval-rule", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "someCondition", + AutomaticReview: &accessmonitoringrulesv1.AutomaticReview{ + Integration: "automaticReviewIntegration", + Decision: types.RequestState_APPROVED.String(), + }, + }, + }, + req: &accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest{ + Subjects: []string{types.KindAccessRequest}, + AutomaticReviewName: "automaticReviewIntegration", + }, + expectedRule: true, + }, + { + description: "filter by both notification and automatic_review integration", + rule: &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example-combined-rule", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "someCondition", + Notification: &accessmonitoringrulesv1.Notification{ + Name: "notificationIntegration", + }, + AutomaticReview: &accessmonitoringrulesv1.AutomaticReview{ + Integration: "automaticReviewIntegration", + Decision: types.RequestState_APPROVED.String(), + }, + }, + }, + req: &accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest{ + Subjects: []string{types.KindAccessRequest}, + AutomaticReviewName: "automaticReviewIntegration", + NotificationName: "notificationIntegration", + }, + expectedRule: true, + }, + { + description: "filter by builtin automatic_review rules", + rule: &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "example-builtin-automatic_approval-rule", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "someCondition", + Notification: &accessmonitoringrulesv1.Notification{ + Name: "notificationIntegration", + }, + AutomaticReview: &accessmonitoringrulesv1.AutomaticReview{ + Integration: types.BuiltInAutomaticReview, + Decision: types.RequestState_APPROVED.String(), + }, + }, + }, + req: &accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest{ + Subjects: []string{types.KindAccessRequest}, + AutomaticReviewName: types.BuiltInAutomaticReview, + }, + expectedRule: true, + }, + { + description: "no match", + rule: &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "no-match-rule", + }, + Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ + Subjects: []string{types.KindAccessRequest}, + Condition: "someCondition", + AutomaticReview: &accessmonitoringrulesv1.AutomaticReview{ + Integration: types.BuiltInAutomaticReview, + Decision: types.RequestState_APPROVED.String(), + }, + }, + }, + req: &accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest{ + Subjects: []string{types.KindAccessRequest}, + AutomaticReviewName: "automaticReviewIntegration", + }, + expectedRule: false, + }, + } + + ctx := t.Context() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + _, err := p.accessMonitoringRules.CreateAccessMonitoringRule(ctx, test.rule) + require.NoError(t, err) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + results, next, err := p.cache.ListAccessMonitoringRules(ctx, 0, "") + assert.NoError(t, err) + assert.Empty(t, next) + assert.Len(t, results, 1) + }, + 15*time.Second, 100*time.Millisecond) + + rules, _, err := p.cache.ListAccessMonitoringRulesWithFilter(ctx, test.req) + require.NoError(t, err) + if test.expectedRule { + require.Len(t, rules, 1) + require.Empty(t, cmp.Diff(test.rule, rules[0], protocmp.Transform())) + } else { + require.Empty(t, rules) + } + + require.NoError(t, p.accessMonitoringRules.DeleteAccessMonitoringRule(ctx, test.rule.GetMetadata().GetName())) + }) + } +} diff --git a/lib/cache/cache.go b/lib/cache/cache.go index ce6ebc24c524f..1448475a567e1 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -37,7 +37,6 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" - accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" dbobjectv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobject/v1" identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1" kubewaitingcontainerpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/kubewaitingcontainer/v1" @@ -510,7 +509,6 @@ type Cache struct { webTokenCache types.WebTokenInterface dynamicWindowsDesktopsCache services.DynamicWindowsDesktops userGroupsCache services.UserGroups - integrationsCache services.Integrations userTasksCache services.UserTasks discoveryConfigsCache services.DiscoveryConfigs headlessAuthenticationsCache services.HeadlessAuthenticationService @@ -519,7 +517,6 @@ type Cache struct { eventsFanout *services.FanoutV2 lowVolumeEventsFanout *utils.RoundRobin[*services.FanoutV2] kubeWaitingContsCache *local.KubeWaitingContainerService - accessMontoringRuleCache services.AccessMonitoringRules staticHostUsersCache *local.StaticHostUserService provisioningStatesCache *local.ProvisioningStateService identityCenterCache *local.IdentityCenterService @@ -922,12 +919,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - integrationsCache, err := local.NewIntegrationsService(config.Backend, local.WithIntegrationsServiceCacheMode(true)) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - userTasksCache, err := local.NewUserTasksService(config.Backend) if err != nil { cancel() @@ -958,12 +949,6 @@ func New(config Config) (*Cache, error) { return nil, trace.Wrap(err) } - accessMonitoringRuleCache, err := local.NewAccessMonitoringRulesService(config.Backend) - if err != nil { - cancel() - return nil, trace.Wrap(err) - } - fanout := services.NewFanoutV2(services.FanoutV2Config{}) lowVolumeFanouts := make([]*services.FanoutV2, 0, config.FanoutShards) for i := 0; i < config.FanoutShards; i++ { @@ -1034,9 +1019,7 @@ func New(config Config) (*Cache, error) { restrictionsCache: local.NewRestrictionsService(config.Backend), webTokenCache: identityService.WebTokens(), dynamicWindowsDesktopsCache: dynamicDesktopsService, - accessMontoringRuleCache: accessMonitoringRuleCache, userGroupsCache: userGroupsCache, - integrationsCache: integrationsCache, userTasksCache: userTasksCache, discoveryConfigsCache: discoveryConfigsCache, headlessAuthenticationsCache: identityService, @@ -2124,32 +2107,6 @@ func (c *Cache) ListDynamicWindowsDesktops(ctx context.Context, pageSize int, ne return rg.reader.ListDynamicWindowsDesktops(ctx, pageSize, nextPage) } -// ListIntegrations returns a paginated list of all Integrations resources. -func (c *Cache) ListIntegrations(ctx context.Context, pageSize int, nextKey string) ([]types.Integration, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListIntegrations") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.integrations) - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - return rg.reader.ListIntegrations(ctx, pageSize, nextKey) -} - -// GetIntegration returns the specified Integration resources. -func (c *Cache) GetIntegration(ctx context.Context, name string) (types.Integration, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetIntegration") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.integrations) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetIntegration(ctx, name) -} - // ListUserTasks returns a list of UserTask resources. func (c *Cache) ListUserTasks(ctx context.Context, pageSize int64, nextKey string, filters *usertasksv1.ListUserTasksFilters) ([]*usertasksv1.UserTask, string, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListUserTasks") @@ -2356,49 +2313,6 @@ func (c *Cache) GetUserLoginState(ctx context.Context, name string) (*userlogins return uls, trace.Wrap(err) } -// ListAccessMonitoringRules returns a paginated list of access monitoring rules. -func (c *Cache) ListAccessMonitoringRules(ctx context.Context, pageSize int, nextToken string) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListAccessMonitoringRules") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessMonitoringRules) - - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - out, nextKey, err := rg.reader.ListAccessMonitoringRules(ctx, pageSize, nextToken) - return out, nextKey, trace.Wrap(err) -} - -// ListAccessMonitoringRulesWithFilter returns a paginated list of access monitoring rules. -func (c *Cache) ListAccessMonitoringRulesWithFilter(ctx context.Context, req *accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) { - ctx, span := c.Tracer.Start(ctx, "cache/ListAccessMonitoringRules") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessMonitoringRules) - - if err != nil { - return nil, "", trace.Wrap(err) - } - defer rg.Release() - out, nextKey, err := rg.reader.ListAccessMonitoringRulesWithFilter(ctx, req) - return out, nextKey, trace.Wrap(err) -} - -// GetAccessMonitoringRule returns the specified AccessMonitoringRule resources. -func (c *Cache) GetAccessMonitoringRule(ctx context.Context, name string) (*accessmonitoringrulesv1.AccessMonitoringRule, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetAccessMonitoringRule") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.accessMonitoringRules) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - return rg.reader.GetAccessMonitoringRule(ctx, name) -} - // ListResources is a part of auth.Cache implementation func (c *Cache) ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListResources") diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index abb11e43e1b43..b586bf06a1bc7 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -1321,44 +1321,6 @@ func TestLocks(t *testing.T) { }) } -// TestIntegrations tests that CRUD operations on integrations resources are -// replicated from the backend to the cache. -func TestIntegrations(t *testing.T) { - t.Parallel() - - p := newTestPack(t, ForAuth) - t.Cleanup(p.Close) - - testResources(t, p, testFuncs[types.Integration]{ - newResource: func(name string) (types.Integration, error) { - return types.NewIntegrationAWSOIDC( - types.Metadata{Name: name}, - &types.AWSOIDCIntegrationSpecV1{ - RoleARN: "arn:aws:iam::123456789012:role/OpsTeam", - }, - ) - }, - create: func(ctx context.Context, i types.Integration) error { - _, err := p.integrations.CreateIntegration(ctx, i) - return err - }, - list: func(ctx context.Context) ([]types.Integration, error) { - results, _, err := p.integrations.ListIntegrations(ctx, 0, "") - return results, err - }, - cacheGet: p.cache.GetIntegration, - cacheList: func(ctx context.Context) ([]types.Integration, error) { - results, _, err := p.cache.ListIntegrations(ctx, 0, "") - return results, err - }, - update: func(ctx context.Context, i types.Integration) error { - _, err := p.integrations.UpdateIntegration(ctx, i) - return err - }, - deleteAll: p.integrations.DeleteAllIntegrations, - }) -} - // TestUserTasks tests that CRUD operations on user notification resources are // replicated from the backend to the cache. func TestUserTasks(t *testing.T) { @@ -2799,8 +2761,15 @@ func newGlobalNotification(t *testing.T, title string) *notificationsv1.GlobalNo func newAccessMonitoringRule(t *testing.T) *accessmonitoringrulesv1.AccessMonitoringRule { t.Helper() notification := &accessmonitoringrulesv1.AccessMonitoringRule{ + Kind: types.KindAccessMonitoringRule, + Version: types.V1, + Metadata: &headerv1.Metadata{}, Spec: &accessmonitoringrulesv1.AccessMonitoringRuleSpec{ - Notification: &accessmonitoringrulesv1.Notification{}, + Notification: &accessmonitoringrulesv1.Notification{ + Name: "test", + }, + Subjects: []string{"llama", "shark"}, + Condition: "test", }, } return notification diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 865d38c51b5aa..25defefd63856 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -21,6 +21,7 @@ import ( "github.com/gravitational/trace" + accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" autoupdatev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1" clusterconfigv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1" crownjewelv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/crownjewel/v1" @@ -103,6 +104,9 @@ type collections struct { accessListReviews *collection[*accesslist.Review, accessListReviewIndex] crownJewels *collection[*crownjewelv1.CrownJewel, crownJewelIndex] accessGraphSettings *collection[*clusterconfigv1.AccessGraphSettings, accessGraphSettingsIndex] + integrations *collection[types.Integration, integrationIndex] + pluginStaticCredentials *collection[types.PluginStaticCredentials, pluginStaticCredentialsIndex] + accessMonitoringRules *collection[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleIndex] } // setupCollections ensures that the appropriate [collection] is @@ -488,6 +492,30 @@ func setupCollections(c Config) (*collections, error) { out.accessGraphSettings = collect out.byKind[resourceKind] = out.accessGraphSettings + case types.KindIntegration: + collect, err := newIntegrationCollection(c.Integrations, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.integrations = collect + out.byKind[resourceKind] = out.integrations + case types.KindPluginStaticCredentials: + collect, err := newPluginStaticCredentialsCollection(c.PluginStaticCredentials, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.pluginStaticCredentials = collect + out.byKind[resourceKind] = out.pluginStaticCredentials + case types.KindAccessMonitoringRule: + collect, err := newAccessMonitoringRuleCollection(c.AccessMonitoringRules, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.accessMonitoringRules = collect + out.byKind[resourceKind] = out.accessMonitoringRules } } diff --git a/lib/cache/generic_operations.go b/lib/cache/generic_operations.go index 5b3de6b532d36..bdb5a73075579 100644 --- a/lib/cache/generic_operations.go +++ b/lib/cache/generic_operations.go @@ -83,6 +83,9 @@ type genericLister[T any, I comparable] struct { nextToken func(T) string // clone is used to make a copy of the item returned. clone func(T) T + // filter is an optional function used to exclude items from + // cache reads. + filter func(T) bool } // list retrieves a page of items from the configured cache collection. @@ -115,6 +118,9 @@ func (l genericLister[T, I]) list(ctx context.Context, pageSize int, startToken return out, l.nextToken(sf), nil } + if l.filter != nil && !l.filter(sf) { + continue + } out = append(out, l.clone(sf)) } diff --git a/lib/cache/integrations.go b/lib/cache/integrations.go new file mode 100644 index 0000000000000..8378e956883bc --- /dev/null +++ b/lib/cache/integrations.go @@ -0,0 +1,111 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" +) + +type integrationIndex string + +const integrationNameIndex integrationIndex = "name" + +func newIntegrationCollection(upstream services.Integrations, w types.WatchKind) (*collection[types.Integration, integrationIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter Integrations") + } + + return &collection[types.Integration, integrationIndex]{ + store: newStore(map[integrationIndex]func(types.Integration) string{ + integrationNameIndex: func(r types.Integration) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.Integration, error) { + var startKey string + var resources []types.Integration + for { + var igs []types.Integration + var err error + igs, startKey, err = upstream.ListIntegrations(ctx, 0, startKey) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, igs...) + + if startKey == "" { + break + } + } + + return resources, nil + }, + headerTransform: func(hdr *types.ResourceHeader) types.Integration { + return &types.IntegrationV1{ + ResourceHeader: types.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + }, + } + }, + watch: w, + }, nil +} + +// ListIntegrations returns a paginated list of all Integrations resources. +func (c *Cache) ListIntegrations(ctx context.Context, pageSize int, pageToken string) ([]types.Integration, string, error) { + ctx, span := c.Tracer.Start(ctx, "cache/ListIntegrations") + defer span.End() + + lister := genericLister[types.Integration, integrationIndex]{ + cache: c, + collection: c.collections.integrations, + index: integrationNameIndex, + upstreamList: c.Config.Integrations.ListIntegrations, + nextToken: func(t types.Integration) string { + return t.GetMetadata().Name + }, + clone: types.Integration.Clone, + } + out, next, err := lister.list(ctx, pageSize, pageToken) + return out, next, trace.Wrap(err) +} + +// GetIntegration returns the specified Integration resources. +func (c *Cache) GetIntegration(ctx context.Context, name string) (types.Integration, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetIntegration") + defer span.End() + + getter := genericGetter[types.Integration, integrationIndex]{ + cache: c, + collection: c.collections.integrations, + index: integrationNameIndex, + upstreamGet: c.Config.Integrations.GetIntegration, + clone: types.Integration.Clone, + } + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) +} diff --git a/lib/cache/integrations_test.go b/lib/cache/integrations_test.go new file mode 100644 index 0000000000000..0503c55218837 --- /dev/null +++ b/lib/cache/integrations_test.go @@ -0,0 +1,62 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/teleport/api/types" +) + +// TestIntegrations tests that CRUD operations on integrations resources are +// replicated from the backend to the cache. +func TestIntegrations(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[types.Integration]{ + newResource: func(name string) (types.Integration, error) { + return types.NewIntegrationAWSOIDC( + types.Metadata{Name: name}, + &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:iam::123456789012:role/OpsTeam", + }, + ) + }, + create: func(ctx context.Context, i types.Integration) error { + _, err := p.integrations.CreateIntegration(ctx, i) + return err + }, + list: func(ctx context.Context) ([]types.Integration, error) { + results, _, err := p.integrations.ListIntegrations(ctx, 0, "") + return results, err + }, + cacheGet: p.cache.GetIntegration, + cacheList: func(ctx context.Context) ([]types.Integration, error) { + results, _, err := p.cache.ListIntegrations(ctx, 0, "") + return results, err + }, + update: func(ctx context.Context, i types.Integration) error { + _, err := p.integrations.UpdateIntegration(ctx, i) + return err + }, + deleteAll: p.integrations.DeleteAllIntegrations, + }) +} diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index 64ffea33d356d..917d2ad51512f 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -26,7 +26,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" - accessmonitoringrulesv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" dbobjectv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/dbobject/v1" identitycenterv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/identitycenter/v1" kubewaitingcontainerpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/kubewaitingcontainer/v1" @@ -104,22 +103,18 @@ type legacyCollections struct { databaseObjects collectionReader[services.DatabaseObjectsGetter] discoveryConfigs collectionReader[services.DiscoveryConfigsGetter] installers collectionReader[installerGetter] - integrations collectionReader[services.IntegrationsGetter] userTasks collectionReader[userTasksGetter] kubeWaitingContainers collectionReader[kubernetesWaitingContainerGetter] staticHostUsers collectionReader[staticHostUserGetter] locks collectionReader[services.LockGetter] networkRestrictions collectionReader[networkRestrictionGetter] - proxies collectionReader[services.ProxyGetter] remoteClusters collectionReader[remoteClusterGetter] uiConfigs collectionReader[uiConfigGetter] userLoginStates collectionReader[services.UserLoginStatesGetter] webTokens collectionReader[webTokenGetter] dynamicWindowsDesktops collectionReader[dynamicWindowsDesktopsGetter] - accessMonitoringRules collectionReader[accessMonitoringRuleGetter] provisioningStates collectionReader[provisioningStateGetter] identityCenterPrincipalAssignments collectionReader[identityCenterPrincipalAssignmentGetter] - pluginStaticCredentials collectionReader[pluginStaticCredentialsGetter] gitServers collectionReader[services.GitServerGetter] } @@ -218,15 +213,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect watch: watch, } collections.byKind[resourceKind] = collections.dynamicWindowsDesktops - case types.KindIntegration: - if c.Integrations == nil { - return nil, trace.BadParameter("missing parameter Integrations") - } - collections.integrations = &genericCollection[types.Integration, services.IntegrationsGetter, integrationsExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.integrations case types.KindUserTask: if c.UserTasks == nil { return nil, trace.BadParameter("missing parameter user tasks") @@ -287,13 +273,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect watch: watch, } collections.byKind[resourceKind] = collections.staticHostUsers - case types.KindAccessMonitoringRule: - if c.AccessMonitoringRules == nil { - return nil, trace.BadParameter("missing parameter AccessMonitoringRule") - } - collections.accessMonitoringRules = &genericCollection[*accessmonitoringrulesv1.AccessMonitoringRule, accessMonitoringRuleGetter, accessMonitoringRulesExecutor]{cache: c, watch: watch} - collections.byKind[resourceKind] = collections.accessMonitoringRules - case types.KindProvisioningPrincipalState: if c.ProvisioningStates == nil { return nil, trace.BadParameter("missing parameter KindProvisioningState") @@ -316,21 +295,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect watch: watch, } collections.byKind[resourceKind] = collections.identityCenterPrincipalAssignments - - case types.KindPluginStaticCredentials: - if c.PluginStaticCredentials == nil { - return nil, trace.BadParameter("missing parameter PluginStaticCredentials") - } - collections.pluginStaticCredentials = &genericCollection[ - types.PluginStaticCredentials, - pluginStaticCredentialsGetter, - pluginStaticCredentialsExecutor, - ]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.pluginStaticCredentials - case types.KindGitServer: if c.GitServers == nil { return nil, trace.BadParameter("missing parameter GitServers") @@ -973,58 +937,6 @@ type resourceGetter interface { ListResources(ctx context.Context, req proto.ListResourcesRequest) (*types.ListResourcesResponse, error) } -type integrationsExecutor struct{} - -func (integrationsExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.Integration, error) { - var ( - startKey string - resources []types.Integration - ) - for { - var igs []types.Integration - var err error - igs, startKey, err = cache.Integrations.ListIntegrations(ctx, 0, startKey) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, igs...) - - if startKey == "" { - break - } - } - - return resources, nil -} - -func (integrationsExecutor) upsert(ctx context.Context, cache *Cache, resource types.Integration) error { - _, err := cache.integrationsCache.CreateIntegration(ctx, resource) - if trace.IsAlreadyExists(err) { - _, err = cache.integrationsCache.UpdateIntegration(ctx, resource) - } - return trace.Wrap(err) -} - -func (integrationsExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.integrationsCache.DeleteAllIntegrations(ctx) -} - -func (integrationsExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.integrationsCache.DeleteIntegration(ctx, resource.GetName()) -} - -func (integrationsExecutor) isSingleton() bool { return false } - -func (integrationsExecutor) getReader(cache *Cache, cacheOK bool) services.IntegrationsGetter { - if cacheOK { - return cache.integrationsCache - } - return cache.Config.Integrations -} - -var _ executor[types.Integration, services.IntegrationsGetter] = integrationsExecutor{} - type discoveryConfigExecutor struct{} func (discoveryConfigExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*discoveryconfig.DiscoveryConfig, error) { @@ -1265,52 +1177,3 @@ func (userLoginStateExecutor) getReader(cache *Cache, cacheOK bool) services.Use } var _ executor[*userloginstate.UserLoginState, services.UserLoginStatesGetter] = userLoginStateExecutor{} - -type accessMonitoringRulesExecutor struct{} - -func (accessMonitoringRulesExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]*accessmonitoringrulesv1.AccessMonitoringRule, error) { - var resources []*accessmonitoringrulesv1.AccessMonitoringRule - var nextToken string - for { - var page []*accessmonitoringrulesv1.AccessMonitoringRule - var err error - page, nextToken, err = cache.AccessMonitoringRules.ListAccessMonitoringRules(ctx, 0 /* page size */, nextToken) - if err != nil { - return nil, trace.Wrap(err) - } - resources = append(resources, page...) - - if nextToken == "" { - break - } - } - return resources, nil -} - -func (accessMonitoringRulesExecutor) upsert(ctx context.Context, cache *Cache, resource *accessmonitoringrulesv1.AccessMonitoringRule) error { - _, err := cache.accessMontoringRuleCache.UpsertAccessMonitoringRule(ctx, resource) - return trace.Wrap(err) -} - -func (accessMonitoringRulesExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.accessMontoringRuleCache.DeleteAllAccessMonitoringRules(ctx) -} - -func (accessMonitoringRulesExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.accessMontoringRuleCache.DeleteAccessMonitoringRule(ctx, resource.GetName()) -} - -func (accessMonitoringRulesExecutor) isSingleton() bool { return false } - -func (accessMonitoringRulesExecutor) getReader(cache *Cache, cacheOK bool) accessMonitoringRuleGetter { - if cacheOK { - return cache.accessMontoringRuleCache - } - return cache.Config.AccessMonitoringRules -} - -type accessMonitoringRuleGetter interface { - GetAccessMonitoringRule(ctx context.Context, name string) (*accessmonitoringrulesv1.AccessMonitoringRule, error) - ListAccessMonitoringRules(ctx context.Context, limit int, startKey string) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) - ListAccessMonitoringRulesWithFilter(ctx context.Context, req *accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) -} diff --git a/lib/cache/plugin_static_credentials.go b/lib/cache/plugin_static_credentials.go index a6cb1ee161a81..1d9bd1409e5b7 100644 --- a/lib/cache/plugin_static_credentials.go +++ b/lib/cache/plugin_static_credentials.go @@ -22,63 +22,80 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" ) +type pluginStaticCredentialsIndex string + +const pluginStaticCredentialsNameIndex pluginStaticCredentialsIndex = "name" + +func newPluginStaticCredentialsCollection(upstream services.PluginStaticCredentials, w types.WatchKind) (*collection[types.PluginStaticCredentials, pluginStaticCredentialsIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter PluginStaticCredentials") + } + + return &collection[types.PluginStaticCredentials, pluginStaticCredentialsIndex]{ + store: newStore(map[pluginStaticCredentialsIndex]func(types.PluginStaticCredentials) string{ + pluginStaticCredentialsNameIndex: func(r types.PluginStaticCredentials) string { + return r.GetMetadata().Name + }, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.PluginStaticCredentials, error) { + creds, err := upstream.GetAllPluginStaticCredentials(ctx) + return creds, trace.Wrap(err) + + }, + headerTransform: func(hdr *types.ResourceHeader) types.PluginStaticCredentials { + return &types.PluginStaticCredentialsV1{ + ResourceHeader: types.ResourceHeader{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + Description: hdr.Metadata.Description, + }, + }, + } + }, + watch: w, + }, nil +} + func (c *Cache) GetPluginStaticCredentials(ctx context.Context, name string) (types.PluginStaticCredentials, error) { ctx, span := c.Tracer.Start(ctx, "cache/GetPluginStaticCredentials") defer span.End() - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.pluginStaticCredentials) - if err != nil { - return nil, trace.Wrap(err) + getter := genericGetter[types.PluginStaticCredentials, pluginStaticCredentialsIndex]{ + cache: c, + collection: c.collections.pluginStaticCredentials, + index: pluginStaticCredentialsNameIndex, + upstreamGet: c.Config.PluginStaticCredentials.GetPluginStaticCredentials, + clone: types.PluginStaticCredentials.Clone, } - defer rg.Release() - return rg.reader.GetPluginStaticCredentials(ctx, name) + out, err := getter.get(ctx, name) + return out, trace.Wrap(err) } func (c *Cache) GetPluginStaticCredentialsByLabels(ctx context.Context, labels map[string]string) ([]types.PluginStaticCredentials, error) { ctx, span := c.Tracer.Start(ctx, "cache/GetPluginStaticCredentialsByLabels") defer span.End() - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.pluginStaticCredentials) + rg, err := acquireReadGuard(c, c.collections.pluginStaticCredentials) if err != nil { return nil, trace.Wrap(err) } defer rg.Release() - return rg.reader.GetPluginStaticCredentialsByLabels(ctx, labels) -} - -type pluginStaticCredentialsGetter interface { - GetPluginStaticCredentials(ctx context.Context, name string) (types.PluginStaticCredentials, error) - GetPluginStaticCredentialsByLabels(ctx context.Context, labels map[string]string) ([]types.PluginStaticCredentials, error) -} -var _ executor[types.PluginStaticCredentials, pluginStaticCredentialsGetter] = pluginStaticCredentialsExecutor{} - -type pluginStaticCredentialsExecutor struct{} - -func (pluginStaticCredentialsExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.PluginStaticCredentials, error) { - return cache.PluginStaticCredentials.GetAllPluginStaticCredentials(ctx) -} - -func (pluginStaticCredentialsExecutor) upsert(ctx context.Context, cache *Cache, resource types.PluginStaticCredentials) error { - _, err := cache.pluginStaticCredentialsCache.UpsertPluginStaticCredentials(ctx, resource) - return trace.Wrap(err) -} - -func (pluginStaticCredentialsExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.pluginStaticCredentialsCache.DeleteAllPluginStaticCredentials(ctx) -} - -func (pluginStaticCredentialsExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.pluginStaticCredentialsCache.DeletePluginStaticCredentials(ctx, resource.GetName()) -} - -func (pluginStaticCredentialsExecutor) isSingleton() bool { return false } + if !rg.ReadCache() { + resp, err := c.Config.PluginStaticCredentials.GetPluginStaticCredentialsByLabels(ctx, labels) + return resp, trace.Wrap(err) + } -func (pluginStaticCredentialsExecutor) getReader(cache *Cache, cacheOK bool) pluginStaticCredentialsGetter { - if cacheOK { - return cache.pluginStaticCredentialsCache + var out []types.PluginStaticCredentials + for cred := range rg.store.resources(pluginStaticCredentialsNameIndex, "", "") { + if types.MatchLabels(cred, labels) { + out = append(out, cred.Clone()) + } } - return cache.Config.PluginStaticCredentials + return out, nil } diff --git a/lib/cache/plugin_static_credentials_test.go b/lib/cache/plugin_static_credentials_test.go index c7b05119d092d..4291fc1d4c334 100644 --- a/lib/cache/plugin_static_credentials_test.go +++ b/lib/cache/plugin_static_credentials_test.go @@ -89,8 +89,14 @@ func TestPluginStaticCredentials(t *testing.T) { return err }, deleteAll: p.pluginStaticCredentials.DeleteAllPluginStaticCredentials, - cacheList: p.cache.pluginStaticCredentialsCache.GetAllPluginStaticCredentials, - cacheGet: cacheGet.fn, + cacheList: func(ctx context.Context) ([]types.PluginStaticCredentials, error) { + var out []types.PluginStaticCredentials + for cred := range p.cache.collections.pluginStaticCredentials.store.resources(pluginStaticCredentialsNameIndex, "", "") { + out = append(out, cred.Clone()) + } + return out, nil + }, + cacheGet: cacheGet.fn, changeResource: func(cred types.PluginStaticCredentials) { // types.PluginStaticCredentials does not support Expires. Let's // use labels. diff --git a/lib/services/access_monitoring_rules.go b/lib/services/access_monitoring_rules.go index b5143e7eb9932..90de339aa3b73 100644 --- a/lib/services/access_monitoring_rules.go +++ b/lib/services/access_monitoring_rules.go @@ -121,3 +121,27 @@ func MarshalAccessMonitoringRule(accessMonitoringRule *accessmonitoringrulesv1.A func UnmarshalAccessMonitoringRule(data []byte, opts ...MarshalOption) (*accessmonitoringrulesv1.AccessMonitoringRule, error) { return FastUnmarshalProtoResourceDeprecated[*accessmonitoringrulesv1.AccessMonitoringRule](data, opts...) } + +// MatchAccessMonitoringRule returns true if the provided rule matches the provided match fields. +// The match fields are optional. If a match field is not provided, then the +// rule matches any value for that field. +func MatchAccessMonitoringRule(rule *accessmonitoringrulesv1.AccessMonitoringRule, subjects []string, notificationIntegration, automaticReviewIntegration string) bool { + if notificationIntegration != "" { + if rule.GetSpec().GetNotification().GetName() != notificationIntegration { + return false + } + } + if automaticReviewIntegration != "" { + if rule.GetSpec().GetAutomaticReview().GetIntegration() != automaticReviewIntegration { + return false + } + } + for _, subject := range subjects { + if ok := slices.ContainsFunc(rule.Spec.Subjects, func(s string) bool { + return s == subject + }); !ok { + return false + } + } + return true +} diff --git a/lib/services/access_monitoring_rules_test.go b/lib/services/access_monitoring_rules_test.go index bc7af2cc27691..8a6b35e2f10e8 100644 --- a/lib/services/access_monitoring_rules_test.go +++ b/lib/services/access_monitoring_rules_test.go @@ -44,7 +44,7 @@ func TestValidateAccessMonitoringRule(t *testing.T) { modifyAMR: func(amr *accessmonitoringrulesv1.AccessMonitoringRule) { amr.Spec.Notification.Name = "" }, - assertErr: func(t require.TestingT, err error, i ...interface{}) { + assertErr: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, "notification plugin name is missing") }, }, @@ -53,7 +53,7 @@ func TestValidateAccessMonitoringRule(t *testing.T) { modifyAMR: func(amr *accessmonitoringrulesv1.AccessMonitoringRule) { amr.Spec.AutomaticReview.Integration = "" }, - assertErr: func(t require.TestingT, err error, i ...interface{}) { + assertErr: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, "automatic_review integration is missing") }, }, @@ -62,7 +62,7 @@ func TestValidateAccessMonitoringRule(t *testing.T) { modifyAMR: func(amr *accessmonitoringrulesv1.AccessMonitoringRule) { amr.Spec.AutomaticReview.Decision = "" }, - assertErr: func(t require.TestingT, err error, i ...interface{}) { + assertErr: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, "automatic_review decision is missing") }, }, @@ -72,7 +72,7 @@ func TestValidateAccessMonitoringRule(t *testing.T) { amr.Spec.Notification = nil amr.Spec.AutomaticReview = nil }, - assertErr: func(t require.TestingT, err error, i ...interface{}) { + assertErr: func(t require.TestingT, err error, i ...any) { require.ErrorContains(t, err, "notification or automatic_review must be configured") }, }, diff --git a/lib/services/local/access_monitoring_rules.go b/lib/services/local/access_monitoring_rules.go index 63e566e2dca48..4c938c7e54b0b 100644 --- a/lib/services/local/access_monitoring_rules.go +++ b/lib/services/local/access_monitoring_rules.go @@ -20,7 +20,6 @@ package local import ( "context" - "slices" "github.com/gravitational/trace" @@ -109,34 +108,10 @@ func (s *AccessMonitoringRulesService) DeleteAllAccessMonitoringRules(ctx contex func (s *AccessMonitoringRulesService) ListAccessMonitoringRulesWithFilter(ctx context.Context, req *accessmonitoringrulesv1.ListAccessMonitoringRulesWithFilterRequest) ([]*accessmonitoringrulesv1.AccessMonitoringRule, string, error) { resources, nextKey, err := s.svc.ListResourcesWithFilter(ctx, int(req.GetPageSize()), req.GetPageToken(), func(resource *accessmonitoringrulesv1.AccessMonitoringRule) bool { - return match(resource, req.GetSubjects(), req.GetNotificationName(), req.GetAutomaticReviewName()) + return services.MatchAccessMonitoringRule(resource, req.GetSubjects(), req.GetNotificationName(), req.GetAutomaticReviewName()) }) if err != nil { return nil, "", trace.Wrap(err) } return resources, nextKey, nil } - -// match returns true if the provided rule matches the provided match fields. -// The match fields are optional. If a match field is not provided, then the -// rule matches any value for that field. -func match(rule *accessmonitoringrulesv1.AccessMonitoringRule, subjects []string, notificationIntegration, automaticReviewIntegration string) bool { - if notificationIntegration != "" { - if rule.GetSpec().GetNotification().GetName() != notificationIntegration { - return false - } - } - if automaticReviewIntegration != "" { - if rule.GetSpec().GetAutomaticReview().GetIntegration() != automaticReviewIntegration { - return false - } - } - for _, subject := range subjects { - if ok := slices.ContainsFunc(rule.Spec.Subjects, func(s string) bool { - return s == subject - }); !ok { - return false - } - } - return true -}