diff --git a/integration/hsm/helpers.go b/integration/hsm/helpers.go new file mode 100644 index 0000000000000..15578bd54d48a --- /dev/null +++ b/integration/hsm/helpers.go @@ -0,0 +1,284 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hsm + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" +) + +// teleportService wraps a *service.TeleportProcess and sets up a goroutine to +// handle process reloads. You must always call waitForNewProcess or +// waitForRestart in for the new process after an expected reload to be picked +// up. Methods are not meant to be called concurrently on the same receiver and +// are not generally thread safe. 
+type teleportService struct { + name string + log utils.Logger + config *service.Config + process *service.TeleportProcess + processGeneration int + serviceChannel chan *service.TeleportProcess + errorChannel chan error +} + +func newTeleportService(t *testing.T, config *service.Config, name string) *teleportService { + s := &teleportService{ + config: config, + name: name, + log: config.Log, + serviceChannel: make(chan *service.TeleportProcess, 1), + errorChannel: make(chan error, 1), + } + t.Cleanup(func() { + require.NoError(t, s.close(), "error while closing %s during test cleanup", name) + }) + return s +} + +func (t *teleportService) close() error { + if t.process == nil { + return nil + } + if err := t.process.Close(); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(t.process.Wait()) +} + +func (t *teleportService) start(ctx context.Context) error { + // Run the service in a background goroutine and hook into service.Run to + // receive all new processes after restarts and write them to a goroutine. 
+ go func() { + t.errorChannel <- service.Run(ctx, *t.config, func(cfg *service.Config) (service.Process, error) { + t.log.Debugf("%s gen %d: starting next process generation (gen %d)", t.name, t.processGeneration, t.processGeneration+1) + svc, err := service.NewTeleport(cfg) + if err == nil { + t.log.Debugf("%s gen %d: started, writing to serviceChannel", t.name, t.processGeneration+1) + t.serviceChannel <- svc + } + return svc, trace.Wrap(err) + }) + }() + t.log.Debugf("%s gen 1: waiting for first start", t.name) + if err := t.waitForNewProcess(ctx); err != nil { + return trace.Wrap(err) + } + t.log.Debugf("%s gen 1: started, waiting for it to be ready", t.name) + return t.waitForReady(ctx) +} + +func (t *teleportService) waitForNewProcess(ctx context.Context) error { + select { + case t.process = <-t.serviceChannel: + t.processGeneration += 1 + t.log.Debugf("%s gen %d: received new process from serviceChannel", t.name, t.processGeneration) + case err := <-t.errorChannel: + return trace.Wrap(err) + case <-ctx.Done(): + return trace.Wrap(ctx.Err(), "%s gen %d: timed out waiting for restart", t.name, t.processGeneration) + } + return nil +} + +func (t *teleportService) waitForReady(ctx context.Context) error { + t.log.Debugf("%s gen %d: waiting for TeleportReadyEvent", t.name, t.processGeneration) + if _, err := t.process.WaitForEvent(ctx, service.TeleportReadyEvent); err != nil { + return trace.Wrap(err, "timed out waiting for %s gen %d to be ready", t.name, t.processGeneration) + } + t.log.Debugf("%s gen %d: got TeleportReadyEvent", t.name, t.processGeneration) + // If this is an Auth server, also wait for AuthIdentityEvent so that we + // can safely read the admin credentials and create a test client. 
+ if t.process.GetAuthServer() != nil { + t.log.Debugf("%s gen %d: waiting for AuthIdentityEvent", t.name, t.processGeneration) + if _, err := t.process.WaitForEvent(ctx, service.AuthIdentityEvent); err != nil { + return trace.Wrap(err, "%s gen %d: timed out waiting AuthIdentityEvent", t.name, t.processGeneration) + } + t.log.Debugf("%s gen %d: got AuthIdentityEvent", t.name, t.processGeneration) + } + return nil +} + +func (t *teleportService) waitForRestart(ctx context.Context) error { + t.log.Debugf("%s gen %d: waiting for restart", t.name, t.processGeneration) + if err := t.waitForNewProcess(ctx); err != nil { + return trace.Wrap(err) + } + t.log.Debugf("%s gen %d: restarted, waiting for new process (gen %d) to be ready", t.name, t.processGeneration-1, t.processGeneration) + return trace.Wrap(t.waitForReady(ctx)) +} + +func (t *teleportService) waitForShutdown(ctx context.Context) error { + t.log.Debugf("%s gen %d: waiting for shutdown", t.name, t.processGeneration) + select { + case err := <-t.errorChannel: + t.process = nil + return trace.Wrap(err) + case <-ctx.Done(): + return trace.Wrap(ctx.Err(), "%s gen %d: timed out waiting for shutdown", t.name, t.processGeneration) + } +} + +func (t *teleportService) waitForLocalAdditionalKeys(ctx context.Context) error { + t.log.Debugf("%s gen %d: waiting for local additional keys", t.name, t.processGeneration) + clusterName, err := t.process.GetAuthServer().GetClusterName() + if err != nil { + return trace.Wrap(err) + } + hostCAID := types.CertAuthID{DomainName: clusterName.GetClusterName(), Type: types.HostCA} + for { + select { + case <-ctx.Done(): + return trace.Wrap(ctx.Err(), "%s gen %d: timed out waiting for local additional keys", t.name, t.processGeneration) + case <-time.After(250 * time.Millisecond): + } + ca, err := t.process.GetAuthServer().GetCertAuthority(ctx, hostCAID, true) + if err != nil { + return trace.Wrap(err) + } + hasUsableKeys, err := 
t.process.GetAuthServer().GetKeyStore().HasUsableAdditionalKeys(ctx, ca) + if err != nil { + return trace.Wrap(err) + } + if hasUsableKeys { + break + } + } + t.log.Debugf("%s gen %d has local additional keys", t.name, t.processGeneration) + return nil +} + +func (t *teleportService) waitForPhaseChange(ctx context.Context) error { + t.log.Debugf("%s gen %d: waiting for phase change", t.name, t.processGeneration) + if _, err := t.process.WaitForEvent(ctx, service.TeleportPhaseChangeEvent); err != nil { + return trace.Wrap(err, "%s gen %d: timed out waiting for phase change", t.name, t.processGeneration) + } + t.log.Debugf("%s gen %d: changed phase", t.name, t.processGeneration) + return nil +} + +func (t *teleportService) authAddr(testingT *testing.T) utils.NetAddr { + addr, err := t.process.AuthAddr() + require.NoError(testingT, err) + + return *addr +} + +func (t *teleportService) authAddrString(testingT *testing.T) string { + addr, err := t.process.AuthAddr() + require.NoError(testingT, err) + return addr.String() +} + +type teleportServices []*teleportService + +func (s teleportServices) forEach(f func(t *teleportService) error) error { + for i := range s { + if err := f(s[i]); err != nil { + return trace.Wrap(err) + } + } + return nil +} + +func (s teleportServices) start(ctx context.Context) error { + return s.forEach(func(t *teleportService) error { return t.start(ctx) }) +} + +func (s teleportServices) waitForRestart(ctx context.Context) error { + return s.forEach(func(t *teleportService) error { return t.waitForRestart(ctx) }) +} + +func (s teleportServices) waitForLocalAdditionalKeys(ctx context.Context) error { + return s.forEach(func(t *teleportService) error { return t.waitForLocalAdditionalKeys(ctx) }) +} + +func (s teleportServices) waitForPhaseChange(ctx context.Context) error { + return s.forEach(func(t *teleportService) error { return t.waitForPhaseChange(ctx) }) +} + +func newAuthConfig(t *testing.T, log utils.Logger) *service.Config { + config := 
service.MakeDefaultConfig() + config.DataDir = t.TempDir() + config.Auth.StorageConfig.Params["path"] = filepath.Join(config.DataDir, defaults.BackendDir) + config.SSH.Enabled = false + config.Proxy.Enabled = false + config.Log = log + config.InstanceMetadataClient = cloud.NewDisabledIMDSClient() + config.MaxRetryPeriod = 25 * time.Millisecond + config.PollingPeriod = 2 * time.Second + + config.Auth.Enabled = true + config.Auth.NoAudit = true + config.Auth.ListenAddr.Addr = "localhost:0" + config.Auth.PublicAddrs = []utils.NetAddr{ + { + AddrNetwork: "tcp", + Addr: "localhost", + }, + } + var err error + config.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ + ClusterName: "testcluster", + }) + require.NoError(t, err) + config.SetAuthServerAddress(config.Auth.ListenAddr) + config.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ + StaticTokens: []types.ProvisionTokenV1{ + { + Roles: []types.SystemRole{"Proxy", "Node"}, + Token: "foo", + }, + }, + }) + require.NoError(t, err) + + return config +} + +func newProxyConfig(t *testing.T, authAddr utils.NetAddr, log utils.Logger) *service.Config { + config := service.MakeDefaultConfig() + config.DataDir = t.TempDir() + config.CachePolicy.Enabled = true + config.Auth.Enabled = false + config.SSH.Enabled = false + config.SetToken("foo") + config.SetAuthServerAddress(authAddr) + config.Log = log + config.InstanceMetadataClient = cloud.NewDisabledIMDSClient() + config.MaxRetryPeriod = 25 * time.Millisecond + config.PollingPeriod = 2 * time.Second + + config.Proxy.Enabled = true + config.Proxy.DisableWebInterface = true + config.Proxy.DisableWebService = true + config.Proxy.DisableReverseTunnel = true + config.Proxy.SSHAddr.Addr = "localhost:0" + config.Proxy.WebAddr.Addr = "localhost:0" + + return config +} diff --git a/integration/hsm/hsm_test.go b/integration/hsm/hsm_test.go index 970d3ccd5e471..1262f6f44f639 100644 --- a/integration/hsm/hsm_test.go +++ 
b/integration/hsm/hsm_test.go @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -package integration +package hsm import ( "context" - "net" "os" "path/filepath" "testing" @@ -24,6 +23,7 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" @@ -35,10 +35,8 @@ import ( "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/etcdbk" "github.com/gravitational/teleport/lib/backend/lite" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/service" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) @@ -56,258 +54,10 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -type teleportService struct { - name string - log utils.Logger - config *service.Config - process *service.TeleportProcess - serviceChannel chan *service.TeleportProcess - errorChannel chan error -} - -func newTeleportService(t *testing.T, config *service.Config, name string) *teleportService { - s := &teleportService{ - config: config, - name: name, - log: config.Log, - serviceChannel: make(chan *service.TeleportProcess, 1), - errorChannel: make(chan error, 1), - } - t.Cleanup(func() { - require.NoError(t, s.Close(), "error while closing %s during test cleanup", name) - }) - return s -} - -func (t *teleportService) Close() error { - if t.process == nil { - return nil - } - if err := t.process.Close(); err != nil { - return trace.Wrap(err) - } - return trace.Wrap(t.process.Wait()) -} - -func (t *teleportService) start(ctx context.Context) { - go func() { - t.errorChannel <- service.Run(ctx, *t.config, func(cfg *service.Config) (service.Process, error) { - t.log.Debugf("(Re)starting %s", t.name) - svc, err := service.NewTeleport(cfg) - if 
err == nil { - t.log.Debugf("started %s, writing to serviceChannel", t.name) - t.serviceChannel <- svc - } - return svc, trace.Wrap(err) - }) - }() -} - -func (t *teleportService) waitForStart(ctx context.Context) error { - t.log.Debugf("Waiting for %s to start", t.name) - t.start(ctx) - select { - case t.process = <-t.serviceChannel: - case err := <-t.errorChannel: - return trace.Wrap(err) - case <-ctx.Done(): - return trace.Wrap(ctx.Err(), "timed out waiting for %s to start", t.name) - } - t.log.Debugf("read %s from serviceChannel", t.name) - return t.waitForReady(ctx) -} - -func (t *teleportService) waitForReady(ctx context.Context) error { - t.log.Debugf("Waiting for %s to be ready", t.name) - if _, err := t.process.WaitForEvent(ctx, service.TeleportReadyEvent); err != nil { - return trace.Wrap(err, "timed out waiting for %s to be ready", t.name) - } - // also wait for AuthIdentityEvent so that we can read the admin credentials - // and create a test client - if t.process.GetAuthServer() != nil { - if _, err := t.process.WaitForEvent(ctx, service.AuthIdentityEvent); err != nil { - return trace.Wrap(err, "timed out waiting for %s auth identity event", t.name) - } - t.log.Debugf("%s is ready", t.name) - } - return nil -} - -func (t *teleportService) waitForRestart(ctx context.Context) error { - t.log.Debugf("Waiting for %s to restart", t.name) - // get the new process - select { - case t.process = <-t.serviceChannel: - case err := <-t.errorChannel: - return trace.Wrap(err) - case <-ctx.Done(): - return trace.Wrap(ctx.Err(), "timed out waiting for %s to restart", t.name) - } - - // wait for the new process to be ready - err := t.waitForReady(ctx) - if err != nil { - return trace.Wrap(err) - } - t.log.Debugf("%s successfully restarted", t.name) - return nil -} - -func (t *teleportService) waitForShutdown(ctx context.Context) error { - t.log.Debugf("Waiting for %s to shut down", t.name) - select { - case err := <-t.errorChannel: - t.process = nil - return 
trace.Wrap(err) - case <-ctx.Done(): - return trace.Wrap(ctx.Err(), "timed out waiting for %s to shut down", t.name) - } -} - -func (t *teleportService) waitForLocalAdditionalKeys(ctx context.Context) error { - t.log.Debugf("Waiting for %s to have local additional keys", t.name) - clusterName, err := t.process.GetAuthServer().GetClusterName() - if err != nil { - return trace.Wrap(err) - } - hostCAID := types.CertAuthID{DomainName: clusterName.GetClusterName(), Type: types.HostCA} - for { - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err(), "timed out waiting for %s to have local additional keys", t.name) - case <-time.After(250 * time.Millisecond): - } - ca, err := t.process.GetAuthServer().GetCertAuthority(ctx, hostCAID, true) - if err != nil { - return trace.Wrap(err) - } - hasUsableKeys, err := t.process.GetAuthServer().GetKeyStore().HasUsableAdditionalKeys(ctx, ca) - if err != nil { - return trace.Wrap(err) - } - if hasUsableKeys { - break - } - } - t.log.Debugf("%s has local additional keys", t.name) - return nil -} - -func (t *teleportService) waitForPhaseChange(ctx context.Context) error { - t.log.Debugf("Waiting for %s to change phase", t.name) - if _, err := t.process.WaitForEvent(ctx, service.TeleportPhaseChangeEvent); err != nil { - return trace.Wrap(err, "timed out waiting for %s to change phase", t.name) - } - t.log.Debugf("%s changed phase", t.name) - return nil -} - -func (t *teleportService) AuthAddr(testingT *testing.T) utils.NetAddr { - addr, err := t.process.AuthAddr() - require.NoError(testingT, err) - - return *addr -} - -type TeleportServices []*teleportService - -func (s TeleportServices) forEach(f func(t *teleportService) error) error { - for i := range s { - if err := f(s[i]); err != nil { - return trace.Wrap(err) - } - } - return nil -} - -func (s TeleportServices) waitForStart(ctx context.Context) error { - return s.forEach(func(t *teleportService) error { return t.waitForStart(ctx) }) -} - -func (s TeleportServices) 
waitForRestart(ctx context.Context) error { - return s.forEach(func(t *teleportService) error { return t.waitForRestart(ctx) }) -} - -func (s TeleportServices) waitForLocalAdditionalKeys(ctx context.Context) error { - return s.forEach(func(t *teleportService) error { return t.waitForLocalAdditionalKeys(ctx) }) -} - -func (s TeleportServices) waitForPhaseChange(ctx context.Context) error { - return s.forEach(func(t *teleportService) error { return t.waitForPhaseChange(ctx) }) -} - -func newHSMAuthConfig(ctx context.Context, t *testing.T, storageConfig *backend.Config, log utils.Logger) *service.Config { - hostName, err := os.Hostname() - require.NoError(t, err) - - config := service.MakeDefaultConfig() - config.PollingPeriod = 1 * time.Second - config.SSH.Enabled = false - config.Proxy.Enabled = false - config.Testing.ClientTimeout = time.Second - config.Testing.ShutdownTimeout = time.Minute - config.DataDir = t.TempDir() - config.Auth.ListenAddr.Addr = net.JoinHostPort(hostName, "0") - config.Auth.PublicAddrs = []utils.NetAddr{ - { - AddrNetwork: "tcp", - Addr: hostName, - }, - } - config.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ - ClusterName: "testcluster", - }) - require.NoError(t, err) - config.SetAuthServerAddress(config.Auth.ListenAddr) - config.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ - StaticTokens: []types.ProvisionTokenV1{ - { - Roles: []types.SystemRole{"Proxy", "Node"}, - Token: "foo", - }, - }, - }) - require.NoError(t, err) - config.Log = log - if storageConfig != nil { - config.Auth.StorageConfig = *storageConfig - } - config.CircuitBreakerConfig = breaker.NoopBreakerConfig() - config.InstanceMetadataClient = cloud.NewDisabledIMDSClient() - - if gcpKeyring := os.Getenv("TEST_GCP_KMS_KEYRING"); gcpKeyring != "" { - config.Auth.KeyStore.GCPKMS.KeyRing = gcpKeyring - config.Auth.KeyStore.GCPKMS.ProtectionLevel = "HSM" - } else { - config.Auth.KeyStore = 
keystore.SetupSoftHSMTest(t) - } - - return config -} - -func newProxyConfig(ctx context.Context, t *testing.T, authAddr utils.NetAddr, log utils.Logger) *service.Config { - hostName, err := os.Hostname() - require.NoError(t, err) - - config := service.MakeDefaultConfig() - config.PollingPeriod = 1 * time.Second - config.SetToken("foo") - config.SSH.Enabled = false - config.Auth.Enabled = false - config.Proxy.Enabled = true - config.Proxy.DisableWebInterface = true - config.Proxy.DisableWebService = true - config.Proxy.DisableReverseTunnel = true - config.Proxy.SSHAddr.Addr = net.JoinHostPort(hostName, "0") - config.Proxy.WebAddr.Addr = net.JoinHostPort(hostName, "0") - config.CachePolicy.Enabled = true - config.PollingPeriod = 1 * time.Second - config.Testing.ShutdownTimeout = time.Minute - config.DataDir = t.TempDir() - config.SetAuthServerAddress(authAddr) - config.CircuitBreakerConfig = breaker.NoopBreakerConfig() - config.InstanceMetadataClient = cloud.NewDisabledIMDSClient() - config.Log = log +func newHSMAuthConfig(t *testing.T, storageConfig *backend.Config, log utils.Logger) *service.Config { + config := newAuthConfig(t, log) + config.Auth.StorageConfig = *storageConfig + config.Auth.KeyStore = keystore.HSMTestConfig(t) return config } @@ -326,8 +76,12 @@ func etcdBackendConfig(t *testing.T) *backend.Config { t.Cleanup(func() { bk, err := etcdbk.New(context.Background(), cfg.Params) require.NoError(t, err) - require.NoError(t, bk.DeleteRange(context.Background(), []byte(prefix), - backend.RangeEnd([]byte(prefix))), + + // Based on [backend.Sanitizer] these define the possible range that + // needs to be cleaned up at the end of the test. 
+ firstPossibleKey := []byte("+") + lastPossibleKey := backend.RangeEnd([]byte("z")) + require.NoError(t, bk.DeleteRange(context.Background(), firstPossibleKey, lastPossibleKey), "failed to clean up etcd backend") }) return cfg @@ -351,12 +105,6 @@ func liteBackendConfig(t *testing.T) *backend.Config { } } -func requireHSMAvailable(t *testing.T) { - if os.Getenv("SOFTHSM2_PATH") == "" && os.Getenv("TEST_GCP_KMS_KEYRING") == "" { - t.Skip("Skipping test because neither SOFTHSM2_PATH or TEST_GCP_KMS_KEYRING are set") - } -} - func requireETCDAvailable(t *testing.T) { if os.Getenv("TELEPORT_ETCD_TEST") == "" { t.Skip("Skipping test because TELEPORT_ETCD_TEST is not set") @@ -365,29 +113,26 @@ func requireETCDAvailable(t *testing.T) { // Tests a single CA rotation with a single HSM auth server func TestHSMRotation(t *testing.T) { - requireHSMAvailable(t) - - // pick a conservative timeout - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) log := utils.NewLoggerForTests() log.Debug("TestHSMRotation: starting auth server") - authConfig := newHSMAuthConfig(ctx, t, liteBackendConfig(t), log) + authConfig := newHSMAuthConfig(t, liteBackendConfig(t), log) auth1 := newTeleportService(t, authConfig, "auth1") t.Cleanup(func() { require.NoError(t, auth1.process.GetAuthServer().GetKeyStore().DeleteUnusedKeys(ctx, nil)) }) - teleportServices := TeleportServices{auth1} + allServices := teleportServices{auth1} log.Debug("TestHSMRotation: waiting for auth server to start") - require.NoError(t, auth1.waitForStart(ctx)) + require.NoError(t, auth1.start(ctx)) // start a proxy to make sure it can get creds at each stage of rotation log.Debug("TestHSMRotation: starting proxy") - proxy := newTeleportService(t, newProxyConfig(ctx, t, auth1.AuthAddr(t), log), "proxy") - require.NoError(t, proxy.waitForStart(ctx)) - teleportServices = append(teleportServices, proxy) + proxy := 
newTeleportService(t, newProxyConfig(t, auth1.authAddr(t), log), "proxy") + require.NoError(t, proxy.start(ctx)) + allServices = append(allServices, proxy) log.Debug("TestHSMRotation: sending rotation request init") err := auth1.process.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{ @@ -396,7 +141,7 @@ func TestHSMRotation(t *testing.T) { Mode: types.RotationModeManual, }) require.NoError(t, err) - require.NoError(t, teleportServices.waitForPhaseChange(ctx)) + require.NoError(t, allServices.waitForPhaseChange(ctx)) log.Debug("TestHSMRotation: sending rotation request update_clients") err = auth1.process.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{ @@ -405,7 +150,7 @@ func TestHSMRotation(t *testing.T) { Mode: types.RotationModeManual, }) require.NoError(t, err) - require.NoError(t, teleportServices.waitForRestart(ctx)) + require.NoError(t, allServices.waitForRestart(ctx)) log.Debug("TestHSMRotation: sending rotation request update_servers") err = auth1.process.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{ @@ -414,7 +159,7 @@ func TestHSMRotation(t *testing.T) { Mode: types.RotationModeManual, }) require.NoError(t, err) - require.NoError(t, teleportServices.waitForRestart(ctx)) + require.NoError(t, allServices.waitForRestart(ctx)) log.Debug("TestHSMRotation: sending rotation request standby") err = auth1.process.GetAuthServer().RotateCertAuthority(ctx, auth.RotateRequest{ @@ -423,7 +168,47 @@ func TestHSMRotation(t *testing.T) { Mode: types.RotationModeManual, }) require.NoError(t, err) - require.NoError(t, teleportServices.waitForRestart(ctx)) + require.NoError(t, allServices.waitForRestart(ctx)) +} + +func getAdminClient(authDataDir string, authAddr string) (*auth.Client, error) { + identity, err := auth.ReadLocalIdentity( + filepath.Join(authDataDir, teleport.ComponentProcess), + auth.IdentityID{Role: types.RoleAdmin}) + if err != nil { + return nil, trace.Wrap(err) + } + + tlsConfig, err := identity.TLSConfig(nil 
/*cipherSuites*/) + if err != nil { + return nil, trace.Wrap(err) + } + + clt, err := auth.NewClient(client.Config{ + Addrs: []string{authAddr}, + Credentials: []client.Credentials{ + client.LoadTLS(tlsConfig), + }, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + }) + return clt, trace.Wrap(err) +} + +func testAdminClient(t *testing.T, authDataDir string, authAddr string) { + require.EventuallyWithT(t, func(t *assert.CollectT) { + clt, err := getAdminClient(authDataDir, authAddr) + assert.NoError(t, err) + if err != nil { + return + } + // Make sure it succeeds twice in a row, we might be hitting a load + // balancer in front of two auths, this gives a better chance of testing + // both + for i := 0; i < 2; i++ { + _, err := clt.GetClusterName() + assert.NoError(t, err) + } + }, 10*time.Second, time.Second, "admin client failed test call to GetClusterName") } // Tests multiple CA rotations and rollbacks with 2 HSM auth servers in an HA configuration @@ -432,34 +217,30 @@ func TestHSMDualAuthRotation(t *testing.T) { // https://github.com/gravitational/teleport/issues/20217 t.Skip("TestHSMDualAuthRotation is temporarily disabled due to flakiness") - requireHSMAvailable(t) requireETCDAvailable(t) - // pick a global timeout for the test - ctx, cancel := context.WithTimeout(context.Background(), 8*time.Minute) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) log := utils.NewLoggerForTests() storageConfig := etcdBackendConfig(t) // start a cluster with 1 auth server and a proxy log.Debug("TestHSMDualAuthRotation: Starting auth server 1") - auth1Config := newHSMAuthConfig(ctx, t, storageConfig, log) + auth1Config := newHSMAuthConfig(t, storageConfig, log) auth1 := newTeleportService(t, auth1Config, "auth1") t.Cleanup(func() { require.NoError(t, auth1.process.GetAuthServer().GetKeyStore().DeleteUnusedKeys(ctx, nil), "failed to delete hsm keys during test cleanup") }) - authServices := TeleportServices{auth1} - teleportServices := 
append(TeleportServices{}, authServices...) - require.NoError(t, authServices.waitForStart(ctx), "auth service failed initial startup") + authServices := teleportServices{auth1} + allServices := append(teleportServices{}, authServices...) + require.NoError(t, authServices.start(ctx), "auth service failed initial startup") log.Debug("TestHSMDualAuthRotation: Starting load balancer") - hostName, err := os.Hostname() - require.NoError(t, err) lb, err := utils.NewLoadBalancer( ctx, - *utils.MustParseAddr(net.JoinHostPort(hostName, "0")), - auth1.AuthAddr(t), + *utils.MustParseAddr("localhost:0"), + auth1.authAddr(t), ) require.NoError(t, err) require.NoError(t, lb.Listen()) @@ -468,47 +249,26 @@ func TestHSMDualAuthRotation(t *testing.T) { // start a proxy to make sure it can get creds at each stage of rotation log.Debug("TestHSMDualAuthRotation: Starting proxy") - proxyConfig := newProxyConfig(ctx, t, utils.FromAddr(lb.Addr()), log) + proxyConfig := newProxyConfig(t, utils.FromAddr(lb.Addr()), log) proxy := newTeleportService(t, proxyConfig, "proxy") - require.NoError(t, proxy.waitForStart(ctx), "proxy failed initial startup") - teleportServices = append(teleportServices, proxy) + require.NoError(t, proxy.start(ctx), "proxy failed initial startup") + allServices = append(allServices, proxy) // add a new auth server log.Debug("TestHSMDualAuthRotation: Starting auth server 2") - auth2Config := newHSMAuthConfig(ctx, t, storageConfig, log) + auth2Config := newHSMAuthConfig(t, storageConfig, log) auth2 := newTeleportService(t, auth2Config, "auth2") - require.NoError(t, auth2.waitForStart(ctx)) + require.NoError(t, auth2.start(ctx)) t.Cleanup(func() { require.NoError(t, auth2.process.GetAuthServer().GetKeyStore().DeleteUnusedKeys(ctx, nil)) }) authServices = append(authServices, auth2) - teleportServices = append(teleportServices, auth2) + allServices = append(allServices, auth2) - // make sure the admin identity used by tctl works - getAdminClient := func() *auth.Client { 
- identity, err := auth.ReadLocalIdentity( - filepath.Join(auth2Config.DataDir, teleport.ComponentProcess), - auth.IdentityID{Role: types.RoleAdmin, HostUUID: auth2Config.HostUUID}) - require.NoError(t, err) - tlsConfig, err := identity.TLSConfig(nil) - require.NoError(t, err) - authAddrs := []utils.NetAddr{auth2.AuthAddr(t)} - clt, err := auth.NewClient(client.Config{ - Addrs: utils.NetAddrsToStrings(authAddrs), - Credentials: []client.Credentials{ - client.LoadTLS(tlsConfig), - }, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - }) - require.NoError(t, err) - return clt + testAuth2Client := func(t *testing.T) { + testAdminClient(t, auth2Config.DataDir, auth2.authAddrString(t)) } - testClient := func(clt *auth.Client) error { - _, err = clt.GetClusterName() - return trace.Wrap(err) - } - clt := getAdminClient() - require.NoError(t, testClient(clt)) + testAuth2Client(t) stages := []struct { targetPhase string @@ -517,34 +277,30 @@ func TestHSMDualAuthRotation(t *testing.T) { { targetPhase: types.RotationPhaseInit, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForPhaseChange(ctx)) + require.NoError(t, allServices.waitForPhaseChange(ctx)) require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + testAuth2Client(t) }, }, { targetPhase: types.RotationPhaseUpdateClients, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testAuth2Client(t) }, }, { targetPhase: types.RotationPhaseUpdateServers, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testAuth2Client(t) }, }, { targetPhase: types.RotationPhaseStandby, verify: func(t *testing.T) { - require.NoError(t, 
teleportServices.waitForRestart(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testAuth2Client(t) }, }, } @@ -561,34 +317,12 @@ func TestHSMDualAuthRotation(t *testing.T) { } // Safe to send traffic to new auth server now that a full rotation has been completed. - lb.AddBackend(auth2.AuthAddr(t)) + lb.AddBackend(auth2.authAddr(t)) - // load balanced client shoud work with either backend - getAdminClient = func() *auth.Client { - identity, err := auth.ReadLocalIdentity( - filepath.Join(auth2Config.DataDir, teleport.ComponentProcess), - auth.IdentityID{Role: types.RoleAdmin, HostUUID: auth2Config.HostUUID}) - require.NoError(t, err) - tlsConfig, err := identity.TLSConfig(nil) - require.NoError(t, err) - authAddrs := []string{lb.Addr().String()} - clt, err := auth.NewClient(client.Config{ - Addrs: authAddrs, - Credentials: []client.Credentials{ - client.LoadTLS(tlsConfig), - }, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - }) - require.NoError(t, err) - return clt + testLoadBalancedClient := func(t *testing.T) { + testAdminClient(t, auth2Config.DataDir, lb.Addr().String()) } - testClient = func(clt *auth.Client) error { - _, err1 := clt.GetClusterName() - _, err2 := clt.GetClusterName() - return trace.NewAggregate(err1, err2) - } - clt = getAdminClient() - require.NoError(t, testClient(clt)) + testLoadBalancedClient(t) // Do another full rotation from the new auth server for _, stage := range stages { @@ -609,100 +343,88 @@ func TestHSMDualAuthRotation(t *testing.T) { { targetPhase: types.RotationPhaseInit, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForPhaseChange(ctx)) + require.NoError(t, allServices.waitForPhaseChange(ctx)) require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseRollback, verify: func(t 
*testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseStandby, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseInit, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForPhaseChange(ctx)) + require.NoError(t, allServices.waitForPhaseChange(ctx)) require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseUpdateClients, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseRollback, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseStandby, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseInit, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForPhaseChange(ctx)) + require.NoError(t, allServices.waitForPhaseChange(ctx)) require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx)) - clt := getAdminClient() - 
require.NoError(t, testClient(clt)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseUpdateClients, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseUpdateServers, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseRollback, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, { targetPhase: types.RotationPhaseStandby, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testLoadBalancedClient(t) }, }, } @@ -719,34 +441,30 @@ func TestHSMDualAuthRotation(t *testing.T) { // Tests a dual-auth server migration from raw keys to HSM keys func TestHSMMigrate(t *testing.T) { - requireHSMAvailable(t) requireETCDAvailable(t) - // pick a global timeout for the test - ctx, cancel := context.WithTimeout(context.Background(), 8*time.Minute) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) log := utils.NewLoggerForTests() storageConfig := etcdBackendConfig(t) // start a dual auth non-hsm cluster log.Debug("TestHSMMigrate: Starting auth server 1") - auth1Config := newHSMAuthConfig(ctx, t, storageConfig, log) + auth1Config := newHSMAuthConfig(t, storageConfig, log) auth1Config.Auth.KeyStore = keystore.Config{} auth1 := newTeleportService(t, auth1Config, 
"auth1") - auth2Config := newHSMAuthConfig(ctx, t, storageConfig, log) + auth2Config := newHSMAuthConfig(t, storageConfig, log) auth2Config.Auth.KeyStore = keystore.Config{} auth2 := newTeleportService(t, auth2Config, "auth2") - require.NoError(t, auth1.waitForStart(ctx)) - require.NoError(t, auth2.waitForStart(ctx)) + require.NoError(t, auth1.start(ctx)) + require.NoError(t, auth2.start(ctx)) log.Debug("TestHSMMigrate: Starting load balancer") - hostName, err := os.Hostname() - require.NoError(t, err) lb, err := utils.NewLoadBalancer( ctx, - *utils.MustParseAddr(net.JoinHostPort(hostName, "0")), - auth1.AuthAddr(t), - auth2.AuthAddr(t), + *utils.MustParseAddr("localhost:0"), + auth1.authAddr(t), + auth2.authAddr(t), ) require.NoError(t, err) require.NoError(t, lb.Listen()) @@ -755,50 +473,27 @@ func TestHSMMigrate(t *testing.T) { // start a proxy to make sure it can get creds at each stage of migration log.Debug("TestHSMMigrate: Starting proxy") - proxyConfig := newProxyConfig(ctx, t, utils.FromAddr(lb.Addr()), log) + proxyConfig := newProxyConfig(t, utils.FromAddr(lb.Addr()), log) proxy := newTeleportService(t, proxyConfig, "proxy") - require.NoError(t, proxy.waitForStart(ctx)) + require.NoError(t, proxy.start(ctx)) - // make sure the admin identity used by tctl works - getAdminClient := func() *auth.Client { - identity, err := auth.ReadLocalIdentity( - filepath.Join(auth2Config.DataDir, teleport.ComponentProcess), - auth.IdentityID{Role: types.RoleAdmin, HostUUID: auth2Config.HostUUID}) - require.NoError(t, err) - tlsConfig, err := identity.TLSConfig(nil) - require.NoError(t, err) - authAddrs := []utils.NetAddr{auth2.AuthAddr(t)} - clt, err := auth.NewClient(client.Config{ - Addrs: utils.NetAddrsToStrings(authAddrs), - Credentials: []client.Credentials{ - client.LoadTLS(tlsConfig), - }, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - }) - require.NoError(t, err) - return clt - } - testClient := func(clt *auth.Client) error { - _, err1 := 
clt.GetClusterName() - _, err2 := clt.GetClusterName() - return trace.NewAggregate(err1, err2) + testClient := func(t *testing.T) { + testAdminClient(t, auth2Config.DataDir, auth2.authAddrString(t)) } - clt := getAdminClient() - require.NoError(t, testClient(clt)) + testClient(t) // Phase 1: migrate auth1 to HSM - lb.RemoveBackend(auth1.AuthAddr(t)) + lb.RemoveBackend(auth1.authAddr(t)) auth1.process.Close() require.NoError(t, auth1.waitForShutdown(ctx)) - auth1Config.Auth.KeyStore = keystore.SetupSoftHSMTest(t) + auth1Config.Auth.KeyStore = keystore.HSMTestConfig(t) auth1 = newTeleportService(t, auth1Config, "auth1") - require.NoError(t, auth1.waitForStart(ctx)) + require.NoError(t, auth1.start(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + testClient(t) - authServices := TeleportServices{auth1, auth2} - teleportServices := TeleportServices{auth1, auth2, proxy} + authServices := teleportServices{auth1, auth2} + allServices := teleportServices{auth1, auth2, proxy} stages := []struct { targetPhase string @@ -807,34 +502,30 @@ func TestHSMMigrate(t *testing.T) { { targetPhase: types.RotationPhaseInit, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForPhaseChange(ctx)) + require.NoError(t, allServices.waitForPhaseChange(ctx)) require.NoError(t, authServices.waitForLocalAdditionalKeys(ctx)) - clt := getAdminClient() - require.NoError(t, testClient(clt)) + testClient(t) }, }, { targetPhase: types.RotationPhaseUpdateClients, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testClient(t) }, }, { targetPhase: types.RotationPhaseUpdateServers, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testClient(t) }, }, { 
targetPhase: types.RotationPhaseStandby, verify: func(t *testing.T) { - require.NoError(t, teleportServices.waitForRestart(ctx)) - clt = getAdminClient() - require.NoError(t, testClient(clt)) + require.NoError(t, allServices.waitForRestart(ctx)) + testClient(t) }, }, } @@ -851,21 +542,20 @@ func TestHSMMigrate(t *testing.T) { } // Safe to send traffic to new auth1 again - lb.AddBackend(auth1.AuthAddr(t)) + lb.AddBackend(auth1.authAddr(t)) // Phase 2: migrate auth2 to HSM - lb.RemoveBackend(auth2.AuthAddr(t)) + lb.RemoveBackend(auth2.authAddr(t)) auth2.process.Close() require.NoError(t, auth2.waitForShutdown(ctx)) - auth2Config.Auth.KeyStore = keystore.SetupSoftHSMTest(t) + auth2Config.Auth.KeyStore = keystore.HSMTestConfig(t) auth2 = newTeleportService(t, auth2Config, "auth2") - require.NoError(t, auth2.waitForStart(ctx)) + require.NoError(t, auth2.start(ctx)) - authServices = TeleportServices{auth1, auth2} - teleportServices = TeleportServices{auth1, auth2, proxy} + authServices = teleportServices{auth1, auth2} + allServices = teleportServices{auth1, auth2, proxy} - clt = getAdminClient() - require.NoError(t, testClient(clt)) + testClient(t) // do a full rotation for _, stage := range stages { @@ -879,6 +569,6 @@ func TestHSMMigrate(t *testing.T) { } // Safe to send traffic to new auth2 again - lb.AddBackend(auth2.AuthAddr(t)) - require.NoError(t, testClient(clt)) + lb.AddBackend(auth2.authAddr(t)) + testClient(t) } diff --git a/integration/hsm/reload_test.go b/integration/hsm/reload_test.go new file mode 100644 index 0000000000000..c838c17f1b980 --- /dev/null +++ b/integration/hsm/reload_test.go @@ -0,0 +1,81 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hsm + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/utils" +) + +const ( + totalReloads = 64 + concurrency = 8 +) + +// TestReloads starts up an Auth and Proxy process and repeatedly reloads both +// of them, asserting that the reload is always successful in a reasonable +// amount of time. This is meant to be a simplified test that should be able to +// catch flaky Teleport reload bugs that have been caught by the HSM tests in +// the past. +func TestReloads(t *testing.T) { + for i := 0; i < concurrency; i++ { + t.Run(fmt.Sprintf("%d", i), testReloads) + } +} + +func testReloads(t *testing.T) { + t.Parallel() + log := utils.NewLoggerForTests() + testCtx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + authConfig := newAuthConfig(t, log) + auth := newTeleportService(t, authConfig, "auth") + require.NoError(t, auth.start(testCtx)) + t.Cleanup(func() { require.NoError(t, auth.close()) }) + + proxyConfig := newProxyConfig(t, auth.authAddr(t), log) + proxy := newTeleportService(t, proxyConfig, "proxy") + require.NoError(t, proxy.start(testCtx)) + t.Cleanup(func() { require.NoError(t, proxy.close()) }) + + for i := 0; i < totalReloads/concurrency; i++ { + // Each reload event is broadcast in its own goroutine to try to make + // the reloads as simultaneous as possible, or at least introduce some + // randomness, to maximize the chance of catching errors. 
+ go func() { + auth.process.BroadcastEvent(service.Event{Name: service.TeleportReloadEvent}) + }() + go func() { + proxy.process.BroadcastEvent(service.Event{Name: service.TeleportReloadEvent}) + }() + + require.NoError(t, withTimeout(testCtx, 30*time.Second, auth.waitForRestart), "attempt %d: waiting for auth restart", i+1) + require.NoError(t, withTimeout(testCtx, 30*time.Second, proxy.waitForRestart), "attempt %d: waiting for proxy restart", i+1) + } +} + +func withTimeout(ctx context.Context, d time.Duration, f func(context.Context) error) error { + ctx, cancel := context.WithTimeout(ctx, d) + defer cancel() + return f(ctx) +} diff --git a/lib/auth/keystore/keystore_test.go b/lib/auth/keystore/keystore_test.go index d99caebc9e2d6..bef8a73e81b28 100644 --- a/lib/auth/keystore/keystore_test.go +++ b/lib/auth/keystore/keystore_test.go @@ -23,12 +23,11 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509/pkix" - "log" - "os" "testing" "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -37,6 +36,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" ) var ( @@ -137,151 +137,15 @@ func TestKeyStore(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - skipSoftHSM := os.Getenv("SOFTHSM2_PATH") == "" - var softHSMConfig Config - if !skipSoftHSM { - softHSMConfig = SetupSoftHSMTest(t) - softHSMConfig.PKCS11.HostUUID = "server1" - } + message := []byte("Lorem ipsum dolor sit amet...") + messageHash := sha256.Sum256(message) - hostUUID := uuid.NewString() - - gcpKMSConfig := GCPKMSConfig{ - HostUUID: hostUUID, - ProtectionLevel: "HSM", - } - if keyRing := os.Getenv("TEST_GCP_KMS_KEYRING"); keyRing != "" { - t.Logf("Running test with real GCP KMS keyring %s", keyRing) - 
gcpKMSConfig.KeyRing = keyRing - } else { - t.Log("Running test with fake GCP KMS service") - _, dialer := newTestGCPKMSService(t) - testClient := newTestGCPKMSClient(t, dialer) - gcpKMSConfig.kmsClientOverride = testClient - gcpKMSConfig.KeyRing = "test-keyring" - } - - yubiSlotNumber := 0 - backends := []struct { - desc string - config Config - isSoftware bool - shouldSkip func() bool - // unusedRawKey should return passable raw key identifier for this - // backend that would not actually exist in the backend. - unusedRawKey func(t *testing.T) []byte - }{ - { - desc: "software", - config: Config{ - Software: SoftwareConfig{ - RSAKeyPairSource: native.GenerateKeyPair, - }, - }, - isSoftware: true, - shouldSkip: func() bool { return false }, - unusedRawKey: func(t *testing.T) []byte { - rawKey, _, err := native.GenerateKeyPair() - require.NoError(t, err) - return rawKey - }, - }, - { - desc: "softhsm", - config: softHSMConfig, - shouldSkip: func() bool { - if skipSoftHSM { - log.Println("Skipping softhsm test because SOFTHSM2_PATH is not set.") - return true - } - return false - }, - unusedRawKey: func(t *testing.T) []byte { - rawKey, err := keyID{ - HostID: softHSMConfig.PKCS11.HostUUID, - KeyID: "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF", - }.marshal() - require.NoError(t, err) - return rawKey - }, - }, - { - desc: "yubihsm", - config: Config{ - PKCS11: PKCS11Config{ - Path: os.Getenv("YUBIHSM_PKCS11_PATH"), - SlotNumber: &yubiSlotNumber, - Pin: "0001password", - HostUUID: hostUUID, - }, - }, - shouldSkip: func() bool { - if os.Getenv("YUBIHSM_PKCS11_CONF") == "" || os.Getenv("YUBIHSM_PKCS11_PATH") == "" { - log.Println("Skipping yubihsm test because YUBIHSM_PKCS11_CONF or YUBIHSM_PKCS11_PATH is not set.") - return true - } - return false - }, - unusedRawKey: func(t *testing.T) []byte { - rawKey, err := keyID{ - HostID: hostUUID, - KeyID: "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF", - }.marshal() - require.NoError(t, err) - return rawKey - }, - }, - { - desc: "cloudhsm", 
- config: Config{ - PKCS11: PKCS11Config{ - Path: "/opt/cloudhsm/lib/libcloudhsm_pkcs11.so", - TokenLabel: "cavium", - Pin: os.Getenv("CLOUDHSM_PIN"), - HostUUID: hostUUID, - }, - }, - shouldSkip: func() bool { - if os.Getenv("CLOUDHSM_PIN") == "" { - log.Println("Skipping cloudhsm test because CLOUDHSM_PIN is not set.") - return true - } - return false - }, - unusedRawKey: func(t *testing.T) []byte { - rawKey, err := keyID{ - HostID: hostUUID, - KeyID: "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF", - }.marshal() - require.NoError(t, err) - return rawKey - }, - }, - { - desc: "gcp kms", - config: Config{ - GCPKMS: gcpKMSConfig, - }, - shouldSkip: func() bool { - return false - }, - unusedRawKey: func(t *testing.T) []byte { - return gcpKMSKeyID{ - keyVersionName: gcpKMSConfig.KeyRing + "/cryptoKeys/FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF" + keyVersionSuffix, - }.marshal() - }, - }, - } - - for _, tc := range backends { - tc := tc - t.Run(tc.desc, func(t *testing.T) { - if tc.shouldSkip() { - t.SkipNow() - } + pack := newTestPack(ctx, t) + for _, backendDesc := range pack.backends { + t.Run(backendDesc.name, func(t *testing.T) { // create the keystore manager - keyStore, err := NewManager(ctx, tc.config) + keyStore, err := NewManager(ctx, backendDesc.config) require.NoError(t, err) // create a key @@ -299,13 +163,11 @@ func TestKeyStore(t *testing.T) { require.NotNil(t, signer) // try signing something - message := []byte("Lorem ipsum dolor sit amet...") - hashed := sha256.Sum256(message) - signature, err := signer.Sign(rand.Reader, hashed[:], crypto.SHA256) + signature, err := signer.Sign(rand.Reader, messageHash[:], crypto.SHA256) require.NoError(t, err) require.NotEmpty(t, signature) // make sure we can verify the signature with a "known good" rsa implementation - err = rsa.VerifyPKCS1v15(signer.Public().(*rsa.PublicKey), crypto.SHA256, hashed[:], signature) + err = rsa.VerifyPKCS1v15(signer.Public().(*rsa.PublicKey), crypto.SHA256, messageHash[:], signature) 
require.NoError(t, err) // make sure we can get the ssh public key @@ -389,7 +251,7 @@ func TestKeyStore(t *testing.T) { }) require.NoError(t, err) - if !tc.isSoftware { + if backendDesc.expectedKeyType != types.PrivateKeyType_RAW { // hsm keyStore should not get any signer from raw keys _, err = keyStore.GetSSHSigner(ctx, ca) require.True(t, trace.IsNotFound(err)) @@ -417,19 +279,16 @@ func TestKeyStore(t *testing.T) { }) } - for _, tc := range backends { - t.Run(tc.desc+"_DeleteUnusedKeys", func(t *testing.T) { - if tc.shouldSkip() { - t.SkipNow() - } - if tc.isSoftware { + for _, backendDesc := range pack.backends { + t.Run(backendDesc.name+"_DeleteUnusedKeys", func(t *testing.T) { + if backendDesc.expectedKeyType == types.PrivateKeyType_RAW { // deleting keys is a no-op for software, we won't get the error // we're expecting t.SkipNow() } // create the keystore manager - keyStore, err := NewManager(ctx, tc.config) + keyStore, err := NewManager(ctx, backendDesc.config) require.NoError(t, err) // create some keys to test DeleteUnusedKeys @@ -460,7 +319,7 @@ func TestKeyStore(t *testing.T) { // Make sure key deletion is aborted when one of the active keys // cannot be found. This makes sure that we don't accidentally // delete current active keys in case the ListKeys operation fails. 
- fakeActiveKey := tc.unusedRawKey(t) + fakeActiveKey := backendDesc.unusedRawKey err = keyStore.DeleteUnusedKeys(ctx, [][]byte{fakeActiveKey}) require.True(t, trace.IsNotFound(err), "expected NotFound error, got %v", err) @@ -470,3 +329,123 @@ func TestKeyStore(t *testing.T) { }) } } + +type testPack struct { + backends []*backendDesc + clock clockwork.FakeClock +} + +type backendDesc struct { + name string + config Config + backend backend + expectedKeyType types.PrivateKeyType + unusedRawKey []byte + deletionDoesNothing bool +} + +func newTestPack(ctx context.Context, t *testing.T) *testPack { + clock := clockwork.NewFakeClock() + var backends []*backendDesc + + hostUUID := uuid.NewString() + logger := utils.NewLoggerForTests() + + unusedPKCS11Key, err := keyID{ + HostID: hostUUID, + KeyID: "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF", + }.marshal() + require.NoError(t, err) + + softwareConfig := Config{Software: SoftwareConfig{ + RSAKeyPairSource: native.GenerateKeyPair, + }} + softwareBackend := newSoftwareKeyStore(&softwareConfig.Software, logger) + backends = append(backends, &backendDesc{ + name: "software", + config: softwareConfig, + backend: softwareBackend, + unusedRawKey: testRawPrivateKey, + deletionDoesNothing: true, + }) + + if config, ok := softHSMTestConfig(t); ok { + config.PKCS11.HostUUID = hostUUID + backend, err := newPKCS11KeyStore(&config.PKCS11, logger) + require.NoError(t, err) + backends = append(backends, &backendDesc{ + name: "softhsm", + config: config, + backend: backend, + expectedKeyType: types.PrivateKeyType_PKCS11, + unusedRawKey: unusedPKCS11Key, + }) + } + + if config, ok := yubiHSMTestConfig(t); ok { + config.PKCS11.HostUUID = hostUUID + backend, err := newPKCS11KeyStore(&config.PKCS11, logger) + require.NoError(t, err) + backends = append(backends, &backendDesc{ + name: "yubihsm", + config: config, + backend: backend, + expectedKeyType: types.PrivateKeyType_PKCS11, + unusedRawKey: unusedPKCS11Key, + }) + } + + if config, ok := 
cloudHSMTestConfig(t); ok { + config.PKCS11.HostUUID = hostUUID + backend, err := newPKCS11KeyStore(&config.PKCS11, logger) + require.NoError(t, err) + backends = append(backends, &backendDesc{ + name: "cloudhsm", + config: config, + backend: backend, + expectedKeyType: types.PrivateKeyType_PKCS11, + unusedRawKey: unusedPKCS11Key, + }) + } + + if config, ok := gcpKMSTestConfig(t); ok { + config.GCPKMS.HostUUID = hostUUID + backend, err := newGCPKMSKeyStore(ctx, &config.GCPKMS, logger) + require.NoError(t, err) + backends = append(backends, &backendDesc{ + name: "gcp_kms", + config: config, + backend: backend, + expectedKeyType: types.PrivateKeyType_GCP_KMS, + unusedRawKey: gcpKMSKeyID{ + keyVersionName: config.GCPKMS.KeyRing + "/cryptoKeys/FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF" + keyVersionSuffix, + }.marshal(), + }) + } + _, gcpKMSDialer := newTestGCPKMSService(t) + testGCPKMSClient := newTestGCPKMSClient(t, gcpKMSDialer) + fakeGCPKMSConfig := Config{ + GCPKMS: GCPKMSConfig{ + HostUUID: hostUUID, + ProtectionLevel: "HSM", + KeyRing: "test-keyring", + kmsClientOverride: testGCPKMSClient, + }, + } + fakeGCPKMSBackend, err := newGCPKMSKeyStore(ctx, &fakeGCPKMSConfig.GCPKMS, logger) + require.NoError(t, err) + backends = append(backends, &backendDesc{ + name: "fake_gcp_kms", + config: fakeGCPKMSConfig, + backend: fakeGCPKMSBackend, + expectedKeyType: types.PrivateKeyType_GCP_KMS, + unusedRawKey: gcpKMSKeyID{ + keyVersionName: "test-keyring/cryptoKeys/FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF" + keyVersionSuffix, + }.marshal(), + }) + + return &testPack{ + backends: backends, + clock: clock, + } +} diff --git a/lib/auth/keystore/pkcs11.go b/lib/auth/keystore/pkcs11.go index 2f012e6f1ab45..ba4f0cdbd1f84 100644 --- a/lib/auth/keystore/pkcs11.go +++ b/lib/auth/keystore/pkcs11.go @@ -117,7 +117,7 @@ func (p *pkcs11KeyStore) findUnusedID() (keyID, error) { // https://developers.yubico.com/YubiHSM2/Concepts/Object_ID.html for id := uint16(1); id < 0xffff; id++ { idBytes := 
[]byte{byte((id >> 8) & 0xff), byte(id & 0xff)} - existingSigner, err := p.ctx.FindKeyPair(idBytes, []byte(p.hostUUID)) + existingSigner, err := p.ctx.FindKeyPair(idBytes, nil /*label*/) // FindKeyPair is expected to return nil, nil if the id is not found, // any error is unexpected. if err != nil { diff --git a/lib/auth/keystore/testhelpers.go b/lib/auth/keystore/testhelpers.go index b6165f2a3a0eb..f150464d7416e 100644 --- a/lib/auth/keystore/testhelpers.go +++ b/lib/auth/keystore/testhelpers.go @@ -28,14 +28,78 @@ import ( "github.com/stretchr/testify/require" ) +func HSMTestConfig(t *testing.T) Config { + if cfg, ok := yubiHSMTestConfig(t); ok { + t.Log("Running test with YubiHSM") + return cfg + } + if cfg, ok := cloudHSMTestConfig(t); ok { + t.Log("Running test with AWS CloudHSM") + return cfg + } + if cfg, ok := gcpKMSTestConfig(t); ok { + t.Log("Running test with GCP KMS") + return cfg + } + if cfg, ok := softHSMTestConfig(t); ok { + t.Log("Running test with SoftHSM") + return cfg + } + t.Skip("No HSM available for test") + return Config{} +} + +func yubiHSMTestConfig(t *testing.T) (Config, bool) { + yubiHSMPath := os.Getenv("TELEPORT_TEST_YUBIHSM_PKCS11_PATH") + yubiHSMPin := os.Getenv("TELEPORT_TEST_YUBIHSM_PIN") + if yubiHSMPath == "" || yubiHSMPin == "" { + return Config{}, false + } + slotNumber := 0 + return Config{ + PKCS11: PKCS11Config{ + Path: yubiHSMPath, + SlotNumber: &slotNumber, + Pin: yubiHSMPin, + }, + }, true +} + +func cloudHSMTestConfig(t *testing.T) (Config, bool) { + cloudHSMPin := os.Getenv("TELEPORT_TEST_CLOUDHSM_PIN") + if cloudHSMPin == "" { + return Config{}, false + } + return Config{ + PKCS11: PKCS11Config{ + Path: "/opt/cloudhsm/lib/libcloudhsm_pkcs11.so", + TokenLabel: "cavium", + Pin: cloudHSMPin, + }, + }, true +} + +func gcpKMSTestConfig(t *testing.T) (Config, bool) { + gcpKeyring := os.Getenv("TELEPORT_TEST_GCP_KMS_KEYRING") + if gcpKeyring == "" { + return Config{}, false + } + return Config{ + GCPKMS: GCPKMSConfig{ + 
KeyRing: gcpKeyring, + ProtectionLevel: "SOFTWARE", + }, + }, true +} + var ( - cachedConfig *Config - cacheMutex sync.Mutex + cachedSoftHSMConfig *Config + cachedSoftHSMConfigMutex sync.Mutex ) -// SetupSoftHSMToken is for use in tests only and creates a test SOFTHSM2 -// token. This should be used for all tests which need to use SoftHSM because -// the library can only be initialized once and SOFTHSM2_PATH and SOFTHSM2_CONF +// softHSMTestConfig is for use in tests only and creates a test SOFTHSM2 token. +// This should be used for all tests which need to use SoftHSM because the +// library can only be initialized once and SOFTHSM2_PATH and SOFTHSM2_CONF // cannot be changed. New tokens added after the library has been initialized // will not be found by the library. // @@ -47,15 +111,17 @@ var ( // delete the token or the entire token directory. Each test should clean up // all keys that it creates because SoftHSM2 gets really slow when there are // many keys for a given token. -func SetupSoftHSMTest(t *testing.T) Config { +func softHSMTestConfig(t *testing.T) (Config, bool) { path := os.Getenv("SOFTHSM2_PATH") - require.NotEqual(t, path, "") + if path == "" { + return Config{}, false + } - cacheMutex.Lock() - defer cacheMutex.Unlock() + cachedSoftHSMConfigMutex.Lock() + defer cachedSoftHSMConfigMutex.Unlock() - if cachedConfig != nil { - return *cachedConfig + if cachedSoftHSMConfig != nil { + return *cachedSoftHSMConfig, true } if os.Getenv("SOFTHSM2_CONF") == "" { @@ -89,12 +155,12 @@ func SetupSoftHSMTest(t *testing.T) Config { require.NoError(t, err, "error attempting to run softhsm2-util") } - cachedConfig = &Config{ + cachedSoftHSMConfig = &Config{ PKCS11: PKCS11Config{ Path: path, TokenLabel: tokenLabel, Pin: "password", }, } - return *cachedConfig + return *cachedSoftHSMConfig, true }