diff --git a/api/types/restrictions.go b/api/types/restrictions.go index adc7e221eb4d0..e7f23045b37e6 100644 --- a/api/types/restrictions.go +++ b/api/types/restrictions.go @@ -20,6 +20,8 @@ import ( "time" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/utils" ) // NetworkRestrictions defines network restrictions applied to SSH session. @@ -33,6 +35,8 @@ type NetworkRestrictions interface { GetDeny() []AddressCondition // SetDeny sets a list of denied network addresses (overrides Allow list) SetDeny(deny []AddressCondition) + // Clone returns a copy of the network restrictions. + Clone() NetworkRestrictions } // NewNetworkRestrictions creates a new NetworkRestrictions with the given name. @@ -46,6 +50,11 @@ func NewNetworkRestrictions() NetworkRestrictions { } } +// Clone returns a copy of the network restrictions. +func (r *NetworkRestrictionsV4) Clone() NetworkRestrictions { + return utils.CloneProtoMsg(r) +} + func (r *NetworkRestrictionsV4) setStaticFields() { if r.Version == "" { r.Version = V4 diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 604502e8c139b..ecb6335e30972 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -501,7 +501,6 @@ type Cache struct { accessCache services.Access dynamicAccessCache services.DynamicAccessExt presenceCache services.Presence - restrictionsCache services.Restrictions userGroupsCache services.UserGroups discoveryConfigsCache services.DiscoveryConfigs headlessAuthenticationsCache services.HeadlessAuthenticationService @@ -970,7 +969,6 @@ func New(config Config) (*Cache, error) { accessCache: local.NewAccessService(config.Backend), dynamicAccessCache: local.NewDynamicAccessService(config.Backend), presenceCache: local.NewPresenceService(config.Backend), - restrictionsCache: local.NewRestrictionsService(config.Backend), userGroupsCache: userGroupsCache, discoveryConfigsCache: discoveryConfigsCache, headlessAuthenticationsCache: identityService, @@ -1733,20 +1731,6 @@ func (c *Cache) processEvent(ctx context.Context, event types.Event) error { return nil } -// GetNetworkRestrictions gets the network restrictions. -func (c *Cache) GetNetworkRestrictions(ctx context.Context) (types.NetworkRestrictions, error) { - ctx, span := c.Tracer.Start(ctx, "cache/GetNetworkRestrictions") - defer span.End() - - rg, err := readLegacyCollectionCache(c, c.legacyCacheCollections.networkRestrictions) - if err != nil { - return nil, trace.Wrap(err) - } - defer rg.Release() - - return rg.reader.GetNetworkRestrictions(ctx) -} - // ListDiscoveryConfigs returns a paginated list of all DiscoveryConfig resources. func (c *Cache) ListDiscoveryConfigs(ctx context.Context, pageSize int, nextKey string) ([]*discoveryconfig.DiscoveryConfig, string, error) { ctx, span := c.Tracer.Start(ctx, "cache/ListDiscoveryConfigs") diff --git a/lib/cache/cert_authority_test.go b/lib/cache/cert_authority_test.go index a96e87685d179..fffe57bed9eb4 100644 --- a/lib/cache/cert_authority_test.go +++ b/lib/cache/cert_authority_test.go @@ -93,7 +93,7 @@ func TestNodeCAFiltering(t *testing.T) { Access: p.cache.accessCache, DynamicAccess: p.cache.dynamicAccessCache, Presence: p.cache.presenceCache, - Restrictions: p.cache.restrictionsCache, + Restrictions: p.cache.Restrictions, SAMLIdPServiceProviders: p.samlIDPServiceProviders, UserGroups: p.userGroups, StaticHostUsers: p.staticHostUsers, diff --git a/lib/cache/collections.go b/lib/cache/collections.go index aafe55eb2c809..5f21360decd8d 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -125,6 +125,7 @@ type collections struct { gitServers *collection[types.Server, gitServerIndex] databaseObjects *collection[*dbobjectv1.DatabaseObject, databaseObjectIndex] staticHostUsers *collection[*userprovisioningv2.StaticHostUser, staticHostUserIndex] + networkRestrictions *collection[types.NetworkRestrictions, networkingRestrictionIndex] } // setupCollections ensures that the appropriate [collection] is @@ -639,6 +640,14 @@ func setupCollections(c Config) (*collections, error) { out.staticHostUsers = collect out.byKind[resourceKind] = out.staticHostUsers + case types.KindNetworkRestrictions: + collect, err := newNetworkingRestrictionCollection(c.Restrictions, watch) + if err != nil { + return nil, trace.Wrap(err) + } + + out.networkRestrictions = collect + out.byKind[resourceKind] = out.networkRestrictions } } diff --git a/lib/cache/legacy_collections.go b/lib/cache/legacy_collections.go index 27da52acf3fce..f36231e6eb58c 100644 --- a/lib/cache/legacy_collections.go +++ b/lib/cache/legacy_collections.go @@ -95,7 +95,6 @@ type legacyCollections struct { secReports collectionReader[services.SecurityReportGetter] secReportsStates collectionReader[services.SecurityReportStateGetter] discoveryConfigs collectionReader[services.DiscoveryConfigsGetter] - networkRestrictions collectionReader[networkRestrictionGetter] provisioningStates collectionReader[provisioningStateGetter] identityCenterPrincipalAssignments collectionReader[identityCenterPrincipalAssignmentGetter] } @@ -113,15 +112,6 @@ func setupLegacyCollections(c *Cache, watches []types.WatchKind) (*legacyCollect return nil, trace.BadParameter("missing parameter DynamicAccess") } collections.byKind[resourceKind] = &genericCollection[types.AccessRequest, noReader, accessRequestExecutor]{cache: c, watch: watch} - case types.KindNetworkRestrictions: - if c.Restrictions == nil { - return nil, trace.BadParameter("missing parameter Restrictions") - } - collections.networkRestrictions = &genericCollection[types.NetworkRestrictions, networkRestrictionGetter, networkRestrictionsExecutor]{ - cache: c, - watch: watch, - } - collections.byKind[resourceKind] = collections.networkRestrictions case types.KindDiscoveryConfig: if c.DiscoveryConfigs == nil { return nil, trace.BadParameter("missing parameter DiscoveryConfigs") @@ -284,43 +274,6 @@ type userGetter interface { var _ executor[types.User, userGetter] = userExecutor{} -type networkRestrictionsExecutor struct{} - -func (networkRestrictionsExecutor) getAll(ctx context.Context, cache *Cache, loadSecrets bool) ([]types.NetworkRestrictions, error) { - restrictions, err := cache.Restrictions.GetNetworkRestrictions(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - return []types.NetworkRestrictions{restrictions}, nil -} - -func (networkRestrictionsExecutor) upsert(ctx context.Context, cache *Cache, resource types.NetworkRestrictions) error { - return cache.restrictionsCache.SetNetworkRestrictions(ctx, resource) -} - -func (networkRestrictionsExecutor) deleteAll(ctx context.Context, cache *Cache) error { - return cache.restrictionsCache.DeleteNetworkRestrictions(ctx) -} - -func (networkRestrictionsExecutor) delete(ctx context.Context, cache *Cache, resource types.Resource) error { - return cache.restrictionsCache.DeleteNetworkRestrictions(ctx) -} - -func (networkRestrictionsExecutor) isSingleton() bool { return true } - -func (networkRestrictionsExecutor) getReader(cache *Cache, cacheOK bool) networkRestrictionGetter { - if cacheOK { - return cache.restrictionsCache - } - return cache.Config.Restrictions -} - -type networkRestrictionGetter interface { - GetNetworkRestrictions(context.Context) (types.NetworkRestrictions, error) -} - -var _ executor[types.NetworkRestrictions, networkRestrictionGetter] = networkRestrictionsExecutor{} - // collectionReader extends the collection interface, adding routing capabilities. type collectionReader[R any] interface { legacyCollection diff --git a/lib/cache/network_restrictions.go b/lib/cache/network_restrictions.go new file mode 100644 index 0000000000000..b65f1b919dd6e --- /dev/null +++ b/lib/cache/network_restrictions.go @@ -0,0 +1,78 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" +) + +type networkingRestrictionIndex string + +const networkingRestrictionNameIndex networkingRestrictionIndex = "name" + +func newNetworkingRestrictionCollection(upstream services.Restrictions, w types.WatchKind) (*collection[types.NetworkRestrictions, networkingRestrictionIndex], error) { + if upstream == nil { + return nil, trace.BadParameter("missing parameter Restrictions") + } + + return &collection[types.NetworkRestrictions, networkingRestrictionIndex]{ + store: newStore(map[networkingRestrictionIndex]func(types.NetworkRestrictions) string{ + networkingRestrictionNameIndex: types.NetworkRestrictions.GetName, + }), + fetcher: func(ctx context.Context, loadSecrets bool) ([]types.NetworkRestrictions, error) { + restrictions, err := upstream.GetNetworkRestrictions(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return []types.NetworkRestrictions{restrictions}, nil + }, + headerTransform: func(hdr *types.ResourceHeader) types.NetworkRestrictions { + return &types.NetworkRestrictionsV4{ + Kind: hdr.Kind, + Version: hdr.Version, + Metadata: types.Metadata{ + Name: hdr.Metadata.Name, + }, + } + }, + watch: w, + }, nil +} + +// GetNetworkRestrictions gets the network restrictions. +func (c *Cache) GetNetworkRestrictions(ctx context.Context) (types.NetworkRestrictions, error) { + ctx, span := c.Tracer.Start(ctx, "cache/GetNetworkRestrictions") + defer span.End() + + getter := genericGetter[types.NetworkRestrictions, networkingRestrictionIndex]{ + cache: c, + collection: c.collections.networkRestrictions, + index: networkingRestrictionNameIndex, + upstreamGet: func(ctx context.Context, s string) (types.NetworkRestrictions, error) { + restriction, err := c.Config.Restrictions.GetNetworkRestrictions(ctx) + return restriction, trace.Wrap(err) + }, + clone: types.NetworkRestrictions.Clone, + } + out, err := getter.get(ctx, types.MetaNameNetworkRestrictions) + return out, trace.Wrap(err) +} diff --git a/lib/cache/networking_restrictions_test.go b/lib/cache/networking_restrictions_test.go new file mode 100644 index 0000000000000..7c2eab170f283 --- /dev/null +++ b/lib/cache/networking_restrictions_test.go @@ -0,0 +1,52 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cache + +import ( + "context" + "testing" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" +) + +func TestNetworkRestrictions(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForAuth) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[types.NetworkRestrictions]{ + newResource: func(name string) (types.NetworkRestrictions, error) { + return types.NewNetworkRestrictions(), nil + }, + create: p.restrictions.SetNetworkRestrictions, + list: func(ctx context.Context) ([]types.NetworkRestrictions, error) { + restrictions, err := p.restrictions.GetNetworkRestrictions(ctx) + return []types.NetworkRestrictions{restrictions}, trace.Wrap(err) + }, + cacheList: func(ctx context.Context) ([]types.NetworkRestrictions, error) { + restrictions, err := p.cache.GetNetworkRestrictions(ctx) + if trace.IsNotFound(err) { + return nil, nil + } + return []types.NetworkRestrictions{restrictions}, trace.Wrap(err) + }, + deleteAll: p.restrictions.DeleteNetworkRestrictions, + }) +} diff --git a/lib/services/local/events.go b/lib/services/local/events.go index 5de3afc1d9acd..de42b4c550d54 100644 --- a/lib/services/local/events.go +++ b/lib/services/local/events.go @@ -1842,16 +1842,11 @@ func (p *networkRestrictionsParser) match(key backend.Key) bool { func (p *networkRestrictionsParser) parse(event backend.Event) (types.Resource, error) { switch event.Type { case types.OpDelete: - name := event.Item.Key.TrimPrefix(backend.NewKey(restrictionsPrefix, network)).String() - if name == "" { - return nil, trace.NotFound("failed parsing %v", event.Item.Key.String()) - } - return &types.ResourceHeader{ Kind: types.KindNetworkRestrictions, Version: types.V1, Metadata: types.Metadata{ - Name: strings.TrimPrefix(name, backend.SeparatorString), + Name: types.MetaNameNetworkRestrictions, Namespace: apidefaults.Namespace, }, }, nil