diff --git a/lib/auth/init.go b/lib/auth/init.go index bd204b435d49c..21e3741d3b6be 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -247,15 +247,32 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e } domainName := cfg.ClusterName.GetClusterName() - lock, err := backend.AcquireLock(ctx, cfg.Backend, domainName, 30*time.Second) - if err != nil { + if err := backend.RunWhileLocked(ctx, + backend.RunWhileLockedConfig{ + LockConfiguration: backend.LockConfiguration{ + Backend: cfg.Backend, + LockName: domainName, + TTL: 30 * time.Second, + }, + RefreshLockInterval: 20 * time.Second, + }, func(ctx context.Context) error { + return trace.Wrap(initCluster(ctx, cfg, asrv)) + }); err != nil { return nil, trace.Wrap(err) } - defer lock.Release(ctx, cfg.Backend) + return asrv, nil +} + +// initCluster configures the cluster based on the user provided configuration. This should +// only be called when the init lock is held to prevent multiple instances of Auth from attempting +// to bootstrap the cluster at the same time. 
+func initCluster(ctx context.Context, cfg InitConfig, asrv *Server) error { + span := oteltrace.SpanFromContext(ctx) + domainName := cfg.ClusterName.GetClusterName() firstStart, err := isFirstStart(ctx, asrv, cfg) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } // if bootstrap resources are supplied, use them to bootstrap backend state @@ -264,10 +281,10 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e if firstStart { log.Infof("Applying %v bootstrap resources (first initialization)", len(cfg.BootstrapResources)) if err := checkResourceConsistency(ctx, asrv.keyStore, domainName, cfg.BootstrapResources...); err != nil { - return nil, trace.Wrap(err, "refusing to bootstrap backend") + return trace.Wrap(err, "refusing to bootstrap backend") } if err := local.CreateResources(ctx, cfg.Backend, cfg.BootstrapResources...); err != nil { - return nil, trace.Wrap(err, "backend bootstrap failed") + return trace.Wrap(err, "backend bootstrap failed") } } else { log.Warnf("Ignoring %v bootstrap resources (previously initialized)", len(cfg.BootstrapResources)) @@ -279,7 +296,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e log.Infof("Applying %v resources (apply-on-startup)", len(cfg.ApplyOnStartupResources)) if err := applyResources(ctx, asrv.Services, cfg.ApplyOnStartupResources); err != nil { - return nil, trace.Wrap(err, "applying resources failed") + return trace.Wrap(err, "applying resources failed") } } @@ -291,7 +308,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e // singletons). However, we need to keep them around while Telekube uses them. 
for _, role := range cfg.Roles { if err := asrv.UpsertRole(ctx, role); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } log.Infof("Created role: %v.", role) } @@ -308,7 +325,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e // this part of code is only used in tests. if err := asrv.CreateCertAuthority(ca); err != nil { if !trace.IsAlreadyExists(err) { - return nil, trace.Wrap(err) + return trace.Wrap(err) } } else { log.Infof("Created trusted certificate authority: %q, type: %q.", ca.GetName(), ca.GetType()) @@ -316,7 +333,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e } for _, tunnel := range cfg.ReverseTunnels { if err := asrv.UpsertReverseTunnel(tunnel); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } log.Infof("Created reverse tunnel: %v.", tunnel) } @@ -398,7 +415,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e }) if err := g.Wait(); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } // Override user passed in cluster name with what is in the backend. @@ -407,7 +424,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e // Migrate Host CA as Database CA before certificates generation. Otherwise, the Database CA will be // generated which we don't want for existing installations. if err := migrateDBAuthority(ctx, asrv); err != nil { - return nil, trace.Wrap(err, "failed to migrate database CA") + return trace.Wrap(err, "failed to migrate database CA") } // generate certificate authorities if they don't exist @@ -506,7 +523,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e }) } if err := g.Wait(); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } // Delete any unused keys from the keyStore. 
This is to avoid exhausting @@ -528,14 +545,14 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e span.AddEvent("migrating legacy resources") // Migrate any legacy resources to new format. if err := migrateLegacyResources(ctx, asrv); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } span.AddEvent("completed migration legacy resources") span.AddEvent("creating presets") // Create presets - convenience and example resources. if err := createPresets(ctx, asrv); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } span.AddEvent("completed creating presets") @@ -546,7 +563,7 @@ func Init(ctx context.Context, cfg InitConfig, opts ...ServerOption) (*Server, e log.Infof("Auth server is skipping periodic operations.") } - return asrv, nil + return nil } // generateAuthority creates a new self-signed authority of the provided type diff --git a/lib/backend/helpers.go b/lib/backend/helpers.go index 5961dcc37be9a..892b6a27fe395 100644 --- a/lib/backend/helpers.go +++ b/lib/backend/helpers.go @@ -54,28 +54,55 @@ func randomID() ([]byte, error) { return bytes[:], nil } +type LockConfiguration struct { + Backend Backend + LockName string + // TTL defines when lock will be released automatically + TTL time.Duration + // RetryInterval defines interval which is used to retry locking after + // initial lock failed due to someone else holding lock. 
+ RetryInterval time.Duration +} + +func (l *LockConfiguration) CheckAndSetDefaults() error { + if l.Backend == nil { + return trace.BadParameter("missing Backend") + } + if l.LockName == "" { + return trace.BadParameter("missing LockName") + } + if l.TTL == 0 { + return trace.BadParameter("missing TTL") + } + if l.RetryInterval == 0 { + l.RetryInterval = 250 * time.Millisecond + } + return nil +} + // AcquireLock grabs a lock that will be released automatically in TTL -func AcquireLock(ctx context.Context, backend Backend, lockName string, ttl time.Duration) (Lock, error) { - if lockName == "" { - return Lock{}, trace.BadParameter("missing parameter lock name") +func AcquireLock(ctx context.Context, cfg LockConfiguration) (Lock, error) { + err := cfg.CheckAndSetDefaults() + if err != nil { + return Lock{}, trace.Wrap(err) } - key := lockKey(lockName) + key := lockKey(cfg.LockName) id, err := randomID() if err != nil { return Lock{}, trace.Wrap(err) } for { // Get will clear TTL on a lock - backend.Get(ctx, key) + cfg.Backend.Get(ctx, key) // CreateVal is atomic: - _, err = backend.Create(ctx, Item{Key: key, Value: id, Expires: backend.Clock().Now().UTC().Add(ttl)}) + _, err = cfg.Backend.Create(ctx, Item{Key: key, Value: id, Expires: cfg.Backend.Clock().Now().UTC().Add(cfg.TTL)}) if err == nil { break // success } if trace.IsAlreadyExists(err) { // locked? 
wait and repeat: select { - case <-backend.Clock().After(250 * time.Millisecond): + case <-cfg.Backend.Clock().After(cfg.RetryInterval): // OK, go around and try again continue @@ -86,7 +113,7 @@ func AcquireLock(ctx context.Context, backend Backend, lockName string, ttl time } return Lock{}, trace.ConvertSystemError(err) } - return Lock{key: key, id: id, ttl: ttl}, nil + return Lock{key: key, id: id, ttl: cfg.TTL}, nil } // Release forces lock release @@ -134,22 +161,52 @@ func (l *Lock) resetTTL(ctx context.Context, backend Backend) error { return nil } +// RunWhileLockedConfig is the configuration for the RunWhileLocked function. +type RunWhileLockedConfig struct { + // LockConfiguration is the configuration used to acquire the lock. + LockConfiguration + + // ReleaseCtxTimeout defines timeout used for calling lock.Release method (optional). + ReleaseCtxTimeout time.Duration + // RefreshLockInterval defines interval at which lock will be refreshed + // if fn is still running (optional). + RefreshLockInterval time.Duration +} + +func (c *RunWhileLockedConfig) CheckAndSetDefaults() error { + if err := c.LockConfiguration.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if c.ReleaseCtxTimeout <= 0 { + c.ReleaseCtxTimeout = 300 * time.Millisecond + } + if c.RefreshLockInterval <= 0 { + c.RefreshLockInterval = c.LockConfiguration.TTL / 2 + } + return nil +} + // RunWhileLocked allows you to run a function while a lock is held. 
-func RunWhileLocked(ctx context.Context, backend Backend, lockName string, ttl time.Duration, fn func(context.Context) error) error { - lock, err := AcquireLock(ctx, backend, lockName, ttl) +func RunWhileLocked(ctx context.Context, cfg RunWhileLockedConfig, fn func(context.Context) error) error { + if err := cfg.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + lock, err := AcquireLock(ctx, cfg.LockConfiguration) if err != nil { return trace.Wrap(err) } subContext, cancelFunction := context.WithCancel(ctx) + defer cancelFunction() stopRefresh := make(chan struct{}) go func() { - refreshAfter := ttl / 2 + refreshAfter := cfg.RefreshLockInterval for { select { - case <-backend.Clock().After(refreshAfter): - if err := lock.resetTTL(ctx, backend); err != nil { + case <-cfg.Backend.Clock().After(refreshAfter): + if err := lock.resetTTL(ctx, cfg.Backend); err != nil { cancelFunction() log.Errorf("%v", err) return @@ -163,7 +220,11 @@ func RunWhileLocked(ctx context.Context, backend Backend, lockName string, ttl t fnErr := fn(subContext) close(stopRefresh) - if err := lock.Release(ctx, backend); err != nil { + // lock.Release should be called with a separate ctx: if someone cancels RunWhileLocked + // via ctx, we still want to at least try releasing the lock. + releaseLockCtx, releaseLockCancel := context.WithTimeout(context.Background(), cfg.ReleaseCtxTimeout) + defer releaseLockCancel() + if err := lock.Release(releaseLockCtx, cfg.Backend); err != nil { return trace.NewAggregate(fnErr, err) } diff --git a/lib/backend/helpers_test.go b/lib/backend/helpers_test.go new file mode 100644 index 0000000000000..18c7e95d4ee2b --- /dev/null +++ b/lib/backend/helpers_test.go @@ -0,0 +1,162 @@ +/* +Copyright 2018 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package backend + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" +) + +func TestLockConfiguration_CheckAndSetDefaults(t *testing.T) { + type mockBackend struct { + Backend + } + tests := []struct { + name string + in, want LockConfiguration + wantErr string + }{ + { + name: "minimum valid", + in: LockConfiguration{ + Backend: mockBackend{}, + LockName: "lock", + TTL: 30 * time.Second, + }, + want: LockConfiguration{ + Backend: mockBackend{}, + LockName: "lock", + TTL: 30 * time.Second, + RetryInterval: 250 * time.Millisecond, + }, + }, + { + name: "set RetryAcquireLockTimeout", + in: LockConfiguration{ + Backend: mockBackend{}, + LockName: "lock", + TTL: 30 * time.Second, + RetryInterval: 10 * time.Second, + }, + want: LockConfiguration{ + Backend: mockBackend{}, + LockName: "lock", + TTL: 30 * time.Second, + RetryInterval: 10 * time.Second, + }, + }, + { + name: "missing backend", + in: LockConfiguration{ + Backend: nil, + }, + wantErr: "missing Backend", + }, + { + name: "missing lock name", + in: LockConfiguration{ + Backend: mockBackend{}, + LockName: "", + }, + wantErr: "missing LockName", + }, + { + name: "missing TTL", + in: LockConfiguration{ + Backend: mockBackend{}, + LockName: "lock", + TTL: 0, + }, + wantErr: "missing TTL", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.in + err := cfg.CheckAndSetDefaults() + if tt.wantErr == "" { + require.NoError(t, err, "CheckAndSetDefaults return unexpected err") + require.Empty(t, cmp.Diff(tt.want, cfg)) + } else { + 
require.ErrorContains(t, err, tt.wantErr) + } + }) + } +} + +func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) { + type mockBackend struct { + Backend + } + lockName := "lock" + ttl := 1 * time.Minute + minimumValidConfig := RunWhileLockedConfig{ + LockConfiguration: LockConfiguration{ + Backend: mockBackend{}, + LockName: lockName, + TTL: ttl, + }, + } + tests := []struct { + name string + input func() RunWhileLockedConfig + want RunWhileLockedConfig + wantErr string + }{ + { + name: "minimum valid config", + input: func() RunWhileLockedConfig { + return minimumValidConfig + }, + want: RunWhileLockedConfig{ + LockConfiguration: LockConfiguration{ + Backend: mockBackend{}, + LockName: lockName, + TTL: ttl, + RetryInterval: 250 * time.Millisecond, + }, + ReleaseCtxTimeout: 300 * time.Millisecond, + // defaults to half of TTL. + RefreshLockInterval: 30 * time.Second, + }, + }, + { + name: "errors from LockConfiguration is passed", + input: func() RunWhileLockedConfig { + cfg := minimumValidConfig + cfg.LockName = "" + return cfg + }, + wantErr: "missing LockName", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.input() + err := cfg.CheckAndSetDefaults() + if tt.wantErr == "" { + require.NoError(t, err, "CheckAndSetDefaults return unexpected err") + require.Empty(t, cmp.Diff(tt.want, cfg)) + } else { + require.ErrorContains(t, err, tt.wantErr) + } + }) + } +} diff --git a/lib/backend/test/suite.go b/lib/backend/test/suite.go index bd70e00f70a18..47ef61bdba1b2 100644 --- a/lib/backend/test/suite.go +++ b/lib/backend/test/suite.go @@ -850,7 +850,7 @@ func testLocking(t *testing.T, newBackend Constructor) { defer requireNoAsyncErrors() // Given a lock named `tok1` on the backend... - lock, err := backend.AcquireLock(ctx, uut, tok1, ttl) + lock, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) require.NoError(t, err) // When I asynchronously release the lock... 
@@ -865,7 +865,7 @@ func testLocking(t *testing.T, newBackend Constructor) { }() // ...and simultaneously attempt to create a new lock with the same name - lock, err = backend.AcquireLock(ctx, uut, tok1, ttl) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) // expect that the asynchronous Release() has executed - we're using the // change in the value of the marker value as a proxy for the Release(). @@ -877,7 +877,7 @@ func testLocking(t *testing.T, newBackend Constructor) { require.NoError(t, lock.Release(ctx, uut)) // Given a lock with the same name as previously-existing, manually-released lock - lock, err = backend.AcquireLock(ctx, uut, tok1, ttl) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) require.NoError(t, err) atomic.StoreInt32(&marker, 7) @@ -892,7 +892,7 @@ func testLocking(t *testing.T, newBackend Constructor) { }() // ...and simultaneously try to acquire another lock with the same name - lock, err = backend.AcquireLock(ctx, uut, tok1, ttl) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) // expect that the asynchronous Release() has executed - we're using the // change in the value of the marker value as a proxy for the call to @@ -906,9 +906,9 @@ func testLocking(t *testing.T, newBackend Constructor) { // Given a pair of locks named `tok1` and `tok2` y := int32(0) - lock1, err := backend.AcquireLock(ctx, uut, tok1, ttl) + lock1, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) require.NoError(t, err) - lock2, err := backend.AcquireLock(ctx, uut, tok2, ttl) + lock2, err := backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok2, TTL: ttl}) require.NoError(t, err) // When I asynchronously release the locks... 
@@ -925,7 +925,7 @@ func testLocking(t *testing.T, newBackend Constructor) { } }() - lock, err = backend.AcquireLock(ctx, uut, tok1, ttl) + lock, err = backend.AcquireLock(ctx, backend.LockConfiguration{Backend: uut, LockName: tok1, TTL: ttl}) require.NoError(t, err) require.Equal(t, int32(15), atomic.LoadInt32(&y)) require.NoError(t, lock.Release(ctx, uut)) diff --git a/lib/services/local/access.go b/lib/services/local/access.go index 3a0fc1c144fdc..8ada6d97fbb97 100644 --- a/lib/services/local/access.go +++ b/lib/services/local/access.go @@ -238,7 +238,13 @@ func (s *AccessService) DeleteAllLocks(ctx context.Context) error { // ReplaceRemoteLocks replaces the set of locks associated with a remote cluster. func (s *AccessService) ReplaceRemoteLocks(ctx context.Context, clusterName string, newRemoteLocks []types.Lock) error { - return backend.RunWhileLocked(ctx, s.Backend, "ReplaceRemoteLocks/"+clusterName, time.Minute, func(ctx context.Context) error { + return backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ + LockConfiguration: backend.LockConfiguration{ + Backend: s.Backend, + LockName: "ReplaceRemoteLocks/" + clusterName, + TTL: time.Minute, + }, + }, func(ctx context.Context) error { remoteLocksKey := backend.Key(locksPrefix, clusterName) origRemoteLocks, err := s.GetRange(ctx, remoteLocksKey, backend.RangeEnd(remoteLocksKey), backend.NoLimit) if err != nil { diff --git a/lib/services/local/generic/generic.go b/lib/services/local/generic/generic.go index 72d6d9266d285..4905fa8e15787 100644 --- a/lib/services/local/generic/generic.go +++ b/lib/services/local/generic/generic.go @@ -286,7 +286,14 @@ func (s *Service[T]) MakeKey(name string) []byte { // RunWhileLocked will run the given function in a backend lock. This is a wrapper around the backend.RunWhileLocked function. 
func (s *Service[T]) RunWhileLocked(ctx context.Context, lockName string, ttl time.Duration, fn func(context.Context, backend.Backend) error) error { - return trace.Wrap(backend.RunWhileLocked(ctx, s.backend, lockName, ttl, func(ctx context.Context) error { - return fn(ctx, s.backend) - })) + return trace.Wrap(backend.RunWhileLocked(ctx, + backend.RunWhileLockedConfig{ + LockConfiguration: backend.LockConfiguration{ + Backend: s.backend, + LockName: lockName, + TTL: ttl, + }, + }, func(ctx context.Context) error { + return fn(ctx, s.backend) + })) }