From a49ca39db6969a593533f6cfb1e88b019d7b7b69 Mon Sep 17 00:00:00 2001 From: David Boslee Date: Tue, 27 May 2025 18:12:27 -0400 Subject: [PATCH 1/2] keystore: add support for aws kms multi-region key replication (#53927) * keystore: add support for aws kms multi-region key replication * update func name ApplyConfig -> ApplyMultiRegionConfig * more descriptive var out -> describeKeyOut * fix typo * better var names * add comment * renaming vars for readability * refactor multi-region auth config * add comment about cert authority lock * fix typo * move funcs up * update comment * copy whole struct instead of individual values --- lib/auth/init.go | 58 ++++- lib/auth/keystore/aws_kms.go | 325 +++++++++++++++++++++++------ lib/auth/keystore/aws_kms_test.go | 274 +++++++++++++++++++++++- lib/auth/keystore/keystore_test.go | 6 +- lib/auth/keystore/manager.go | 23 +- lib/config/configuration.go | 3 +- lib/config/fileconf.go | 5 +- lib/service/servicecfg/auth.go | 17 +- 8 files changed, 624 insertions(+), 87 deletions(-) diff --git a/lib/auth/init.go b/lib/auth/init.go index e542876781ffe..863cb52d082b4 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -787,7 +787,10 @@ func initializeAuthority(ctx context.Context, asrv *Server, caID types.CertAuthI " this cluster and then perform a CA rotation: https://goteleport.com/docs/admin-guides/management/operations/ca-rotation/", caID.Type, strings.Join(allKeyTypes[:numKeyTypes-1], ", "), allKeyTypes[numKeyTypes-1]) } - + ca, err = applyAuthorityConfig(ctx, asrv, ca) + if err != nil { + return nil, nil, trace.Wrap(err) + } keysInUse := collectKeysInUse(ca.GetActiveKeys(), ca.GetAdditionalTrustedKeys()) return usableKeysResult, keysInUse, nil } @@ -807,6 +810,59 @@ func collectKeysInUse(cas ...types.CAKeySet) (keysInUse [][]byte) { return keysInUse } +// applyAuthorityConfig applies the latest keystore config to active keys updating +// the stored CA if any changes occur. +func applyAuthorityConfig(ctx context.Context, asrv *Server, ca types.CertAuthority) (types.CertAuthority, error) { + activeKeys := ca.GetActiveKeys() + var ( + changed bool + err error + ) + + apply := func(curr []byte) ([]byte, error) { + next, err := asrv.keyStore.ApplyMultiRegionConfig(ctx, curr) + if err != nil { + return nil, trace.Wrap(err) + } + if !slices.Equal(curr, next) { + changed = true + } + return next, nil + } + + for _, key := range activeKeys.SSH { + key.PrivateKey, err = apply(key.PrivateKey) + if err != nil { + return nil, trace.Wrap(err) + } + } + for _, key := range activeKeys.TLS { + key.Key, err = apply(key.Key) + if err != nil { + return nil, trace.Wrap(err) + } + } + for _, key := range activeKeys.JWT { + key.PrivateKey, err = apply(key.PrivateKey) + if err != nil { + return nil, trace.Wrap(err) + } + } + if !changed { + return ca, nil + } + if err := ca.SetActiveKeys(activeKeys); err != nil { + return nil, trace.Wrap(err) + } + // This is only executed during cluster init while holding a lock to prevent + // other auth servers from updating CAs simulaniously. + ca, err = asrv.UpdateCertAuthority(ctx, ca) + if err != nil { + return nil, trace.Wrap(err) + } + return ca, nil +} + // generateAuthority creates a new self-signed authority of the provided type // and returns it to the caller. It is the responsibility of callers to persist // the authority. diff --git a/lib/auth/keystore/aws_kms.go b/lib/auth/keystore/aws_kms.go index 28915c62bf5ca..b5be361a1d0e0 100644 --- a/lib/auth/keystore/aws_kms.go +++ b/lib/auth/keystore/aws_kms.go @@ -54,14 +54,19 @@ const ( pendingKeyBaseRetryInterval = time.Second / 2 pendingKeyMaxRetryInterval = 4 * time.Second - pendingKeyTimeout = 30 * time.Second + // TODO(dboslee): waiting on AWS support to answer question regarding + // long time for GetPublicKey to succeed after updating key via UpdatePrimaryRegion. + pendingKeyTimeout = 120 * time.Second ) type awsKMSKeystore struct { kms kmsClient + mrk mrkClient awsAccount string awsRegion string multiRegionEnabled bool + primaryRegion string + replicaRegions map[string]struct{} tags map[string]string clock clockwork.Clock logger *slog.Logger @@ -69,6 +74,8 @@ type awsKMSKeystore struct { func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts *Options) (*awsKMSKeystore, error) { stsClient, kmsClient := opts.awsSTSClient, opts.awsKMSClient + mrkClient := opts.mrkClient + if stsClient == nil || kmsClient == nil { useFIPSEndpoint := aws.FIPSEndpointStateUnset if opts.FIPS { @@ -84,8 +91,14 @@ func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts * if stsClient == nil { stsClient = stsutils.NewFromConfig(awsCfg) } - if kmsClient == nil { - kmsClient = kms.NewFromConfig(awsCfg) + if kmsClient == nil || mrkClient == nil { + realKMS := kms.NewFromConfig(awsCfg) + if kmsClient == nil { + kmsClient = realKMS + } + if mrkClient == nil { + mrkClient = realKMS + } } } id, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) @@ -109,12 +122,24 @@ func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts * if clock == nil { clock = clockwork.NewRealClock() } + primary := cfg.MultiRegion.PrimaryRegion + if primary == "" { + primary = cfg.AWSRegion + } + replicas := make(map[string]struct{}) + for _, region := range append(cfg.MultiRegion.ReplicaRegions, primary, cfg.AWSRegion) { + replicas[region] = struct{}{} + } + return &awsKMSKeystore{ awsAccount: cfg.AWSAccount, awsRegion: cfg.AWSRegion, tags: tags, multiRegionEnabled: cfg.MultiRegion.Enabled, + primaryRegion: primary, + replicaRegions: replicas, kms: kmsClient, + mrk: mrkClient, clock: clock, logger: opts.Logger, }, nil @@ -164,15 +189,18 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites return nil, nil, trace.Errorf("KeyMetadata of generated key is nil") } keyARN := aws.ToString(output.KeyMetadata.Arn) - signer, err := a.newSigner(ctx, keyARN) + key, err := keyIDFromArn(keyARN) + if err != nil { + return nil, nil, trace.Wrap(err) + } + keyID, err := a.applyMRKConfig(ctx, key) + if err != nil { + return nil, nil, trace.Wrap(err) + } + signer, err := a.newSigner(ctx, key) if err != nil { return nil, nil, trace.Wrap(err) } - keyID := awsKMSKeyID{ - arn: keyARN, - account: a.awsAccount, - region: a.awsRegion, - }.marshal() return keyID, signer, nil } @@ -188,21 +216,33 @@ func awsAlgorithm(alg cryptosuites.Algorithm) (kmstypes.KeySpec, error) { // getSigner returns a crypto.Signer for the given key identifier, if it is found. func (a *awsKMSKeystore) getSigner(ctx context.Context, rawKey []byte, publicKey crypto.PublicKey) (crypto.Signer, error) { - keyID, err := parseAWSKMSKeyID(rawKey) + key, err := parseAWSKMSKeyID(rawKey) if err != nil { return nil, trace.Wrap(err) } - return a.newSignerWithPublicKey(ctx, keyID.arn, publicKey) + return a.newSignerWithPublicKey(ctx, key, publicKey) } type awsKMSSigner struct { - keyARN string - pub crypto.PublicKey - kms kmsClient + key awsKMSKeyID + pub crypto.PublicKey + kms kmsClient } -func (a *awsKMSKeystore) newSigner(ctx context.Context, keyARN string) (*awsKMSSigner, error) { - pubkeyDER, err := a.getPublicKeyDER(ctx, keyARN) +func (a *awsKMSKeystore) newSigner(ctx context.Context, key awsKMSKeyID) (*awsKMSSigner, error) { + var pubkeyDER []byte + err := a.retryOnConsistencyError(ctx, func(ctx context.Context) error { + a.logger.DebugContext(ctx, "Fetching public key", "key_arn", key.arn) + output, err := a.kms.GetPublicKey(ctx, &kms.GetPublicKeyInput{ + KeyId: aws.String(key.id), + }) + if err != nil { + a.logger.DebugContext(ctx, "Failed to fetch public key", "key_arn", key.arn, "err", err) + return trace.Wrap(err, "fetching public key") + } + pubkeyDER = output.PublicKey + return nil + }) if err != nil { return nil, trace.Wrap(err) } @@ -211,12 +251,13 @@ func (a *awsKMSKeystore) newSigner(ctx context.Context, keyARN string) (*awsKMSS if err != nil { return nil, trace.Wrap(err, "unexpected error parsing public key der") } - return a.newSignerWithPublicKey(ctx, keyARN, pub) + return a.newSignerWithPublicKey(ctx, key, pub) } -func (a *awsKMSKeystore) getPublicKeyDER(ctx context.Context, keyARN string) ([]byte, error) { - // KMS is eventually-consistent, and this is called immediately after the - // key has been recreated, so a few retries may be necessary. +// retryOnConsistencyError handles retrying KMS key operations that may fail +// temporarily due to eventual consistency. +// https://docs.aws.amazon.com/kms/latest/developerguide/programming-eventual-consistency.html +func (a *awsKMSKeystore) retryOnConsistencyError(ctx context.Context, fn func(ctx context.Context) error) error { retry, err := retryutils.NewRetryV2(retryutils.RetryV2Config{ First: pendingKeyBaseRetryInterval, Driver: retryutils.NewExponentialDriver(pendingKeyBaseRetryInterval), @@ -225,49 +266,41 @@ func (a *awsKMSKeystore) getPublicKeyDER(ctx context.Context, keyARN string) ([] Clock: a.clock, }) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } ctx, cancel := context.WithTimeout(ctx, pendingKeyTimeout) defer cancel() timeout := a.clock.NewTimer(pendingKeyTimeout) defer timeout.Stop() for { - output, err := a.kms.GetPublicKey(ctx, &kms.GetPublicKeyInput{ - KeyId: aws.String(keyARN), - }) + err := fn(ctx) if err == nil { - return output.PublicKey, nil + return nil } - - // Check if the error is one of the two expected eventual consistency - // error types - // https://docs.aws.amazon.com/kms/latest/developerguide/programming-eventual-consistency.html var ( notFound *kmstypes.NotFoundException invalidState *kmstypes.KMSInvalidStateException ) if !errors.As(err, ¬Found) && !errors.As(err, &invalidState) { - return nil, trace.Wrap(err, "unexpected error fetching AWS KMS public key") + return trace.Wrap(err, "unexpected error") } - startedWaiting := a.clock.Now() select { - case t := <-retry.After(): - a.logger.DebugContext(ctx, "Failed to fetch public key, retrying", "key_arn", keyARN, "retry_interval", t.Sub(startedWaiting)) + case <-retry.After(): retry.Inc() case <-ctx.Done(): - return nil, trace.Wrap(ctx.Err()) + return trace.Wrap(ctx.Err()) case <-timeout.Chan(): - return nil, trace.Errorf("timed out waiting for AWS KMS public key") + return trace.Wrap(err, "timeout retrying eventual consistency errors") } } } -func (a *awsKMSKeystore) newSignerWithPublicKey(ctx context.Context, keyARN string, publicKey crypto.PublicKey) (*awsKMSSigner, error) { +func (a *awsKMSKeystore) newSignerWithPublicKey(_ context.Context, key awsKMSKeyID, publicKey crypto.PublicKey) (*awsKMSSigner, error) { return &awsKMSSigner{ - keyARN: keyARN, - pub: publicKey, - kms: a.kms, + key: key, + pub: publicKey, + kms: a.kms, }, nil } @@ -302,7 +335,7 @@ func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpt return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key", opts.HashFunc()) } output, err := a.kms.Sign(context.TODO(), &kms.SignInput{ - KeyId: aws.String(a.keyARN), + KeyId: aws.String(a.key.id), Message: digest, MessageType: kmstypes.MessageTypeDigest, SigningAlgorithm: signingAlg, @@ -332,11 +365,11 @@ func (a *awsKMSKeystore) canSignWithKey(ctx context.Context, raw []byte, keyType if keyType != types.PrivateKeyType_AWS_KMS { return false, nil } - keyID, err := parseAWSKMSKeyID(raw) + key, err := parseAWSKMSKeyID(raw) if err != nil { return false, trace.Wrap(err) } - return keyID.account == a.awsAccount && keyID.region == a.awsRegion, nil + return key.account == a.awsAccount && (key.region == a.awsRegion || key.isMRK()), nil } // DeleteUnusedKeys deletes all keys readable from the AWS KMS account and @@ -372,32 +405,36 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by // calls parseAWSKMSKeyID. return trace.Wrap(err) } - activeAWSKMSKeys[keyID.arn] = 0 + activeAWSKMSKeys[keyID.id] = 0 } var keysToDelete []string var mu sync.RWMutex - err := a.forEachKey(ctx, func(ctx context.Context, keyARN string) error { + err := a.forEachKey(ctx, func(ctx context.Context, arn string) error { + key, err := keyIDFromArn(arn) + if err != nil { + return trace.Wrap(err) + } mu.RLock() - _, active := activeAWSKMSKeys[keyARN] + _, active := activeAWSKMSKeys[key.id] mu.RUnlock() if active { // This is a known active key, record that it was found and return // (since it should never be deleted). mu.Lock() defer mu.Unlock() - activeAWSKMSKeys[keyARN] += 1 + activeAWSKMSKeys[key.id] += 1 return nil } // Check if this key was created by this Teleport cluster. output, err := a.kms.ListResourceTags(ctx, &kms.ListResourceTagsInput{ - KeyId: aws.String(keyARN), + KeyId: aws.String(key.id), }) if err != nil { // It's entirely expected that we won't be allowed to fetch // tags for some keys, don't worry about deleting those. - a.logger.DebugContext(ctx, "failed to fetch tags for AWS KMS key, skipping", "key_arn", keyARN, "error", err) + a.logger.DebugContext(ctx, "failed to fetch tags for AWS KMS key, skipping", "key_arn", arn, "error", err) return nil } @@ -412,17 +449,17 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by // Check if this key is not enabled or was created in the past 5 minutes. describeOutput, err := a.kms.DescribeKey(ctx, &kms.DescribeKeyInput{ - KeyId: aws.String(keyARN), + KeyId: aws.String(key.id), }) if err != nil { - return trace.Wrap(err, "failed to describe AWS KMS key %q", keyARN) + return trace.Wrap(err, "failed to describe AWS KMS key %q", arn) } if describeOutput.KeyMetadata == nil { - return trace.Errorf("failed to describe AWS KMS key %q", keyARN) + return trace.Errorf("failed to describe AWS KMS key %q", arn) } if keyState := describeOutput.KeyMetadata.KeyState; keyState != kmstypes.KeyStateEnabled { a.logger.InfoContext(ctx, "deleteUnusedKeys skipping AWS KMS key which is not in enabled state.", - "key_arn", keyARN, "key_state", keyState) + "key_arn", arn, "key_state", keyState) return nil } creationDate := aws.ToTime(describeOutput.KeyMetadata.CreationDate) @@ -431,13 +468,13 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by // created by a different auth server and just haven't been added to // the backend CA yet (which is why they don't appear in activeKeys). a.logger.InfoContext(ctx, "deleteUnusedKeys skipping AWS KMS key which was created in the past 5 minutes.", - "key_arn", keyARN) + "key_arn", arn) return nil } mu.Lock() defer mu.Unlock() - keysToDelete = append(keysToDelete, keyARN) + keysToDelete = append(keysToDelete, *describeOutput.KeyMetadata.Arn) return nil }) if err != nil { @@ -488,39 +525,196 @@ func (a *awsKMSKeystore) forEachKey(ctx context.Context, fn func(ctx context.Con marker = aws.ToString(output.NextMarker) more = output.Truncated for _, keyEntry := range output.Keys { - keyArn := aws.ToString(keyEntry.KeyArn) + keyID := aws.ToString(keyEntry.KeyArn) errGroup.Go(func() error { - return trace.Wrap(fn(ctx, keyArn)) + return trace.Wrap(fn(ctx, keyID)) }) } } return trace.Wrap(errGroup.Wait()) } +func (a *awsKMSKeystore) applyMultiRegionConfig(ctx context.Context, keyID []byte) ([]byte, error) { + if keyType(keyID) != types.PrivateKeyType_AWS_KMS { + return keyID, nil + } + key, err := parseAWSKMSKeyID(keyID) + if err != nil { + return nil, trace.Wrap(err) + } + keyID, err = a.applyMRKConfig(ctx, key) + if err != nil { + return nil, trace.Wrap(err) + } + return keyID, nil +} + +func (a *awsKMSKeystore) applyMRKConfig(ctx context.Context, key awsKMSKeyID) ([]byte, error) { + if !key.isMRK() { + if a.multiRegionEnabled { + a.logger.WarnContext(ctx, "Unable to replicate single-region key. A CA rotation is required to migrate to a multi-region key.", "key_arn", key.arn) + } + return key.marshal(), nil + } + + tags := make([]kmstypes.Tag, 0, len(a.tags)) + for k, v := range a.tags { + tags = append(tags, kmstypes.Tag{ + TagKey: aws.String(k), + TagValue: aws.String(v), + }) + } + + client := a.mrk + describeKeyOut, err := client.DescribeKey(ctx, &kms.DescribeKeyInput{ + KeyId: aws.String(key.id), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + currRegionKey, err := keyIDFromArn(*describeKeyOut.KeyMetadata.Arn) + if err != nil { + return nil, trace.Wrap(err) + } + if err := a.waitForKeyEnabled(ctx, client, currRegionKey); err != nil { + return nil, trace.Wrap(err) + } + if describeKeyOut.KeyMetadata.MultiRegionConfiguration == nil { + // This error is not expected to be reached since we check that the key + // is a multi-region key above. + return nil, trace.Errorf("kms key %s missing multi-region configuration", currRegionKey.arn) + } + + currPrimaryKey, err := keyIDFromArn(*describeKeyOut.KeyMetadata.MultiRegionConfiguration.PrimaryKey.Arn) + if err != nil { + return nil, trace.Wrap(err) + } + var existingReplicas []awsKMSKeyID + for _, replica := range append( + describeKeyOut.KeyMetadata.MultiRegionConfiguration.ReplicaKeys, + *describeKeyOut.KeyMetadata.MultiRegionConfiguration.PrimaryKey, + ) { + key, err := keyIDFromArn(*replica.Arn) + if err != nil { + return nil, trace.Wrap(err) + } + existingReplicas = append(existingReplicas, key) + } + + // Only the primary region can replicate keys and update the primary region + // so return early if we are operating outside of the primary region. + if currRegionKey.region != currPrimaryKey.region { + return key.marshal(), nil + } + + for region := range a.replicaRegions { + // Check if a replica already exists in this region. + if slices.ContainsFunc(existingReplicas, func(key awsKMSKeyID) bool { + return key.region == region + }) { + continue + } + a.logger.DebugContext(ctx, "Replicating key", "kms_arn", currPrimaryKey.arn, "replica_region", region) + out, err := client.ReplicateKey(ctx, &kms.ReplicateKeyInput{ + KeyId: &key.id, + ReplicaRegion: ®ion, + Tags: tags, + }) + if err != nil { + return nil, trace.Wrap(err) + } + key, err := keyIDFromArn(*out.ReplicaKeyMetadata.Arn) + if err != nil { + return nil, trace.Wrap(err) + } + existingReplicas = append(existingReplicas, key) + } + if currPrimaryKey.region == a.primaryRegion { + return currPrimaryKey.marshal(), nil + } + + err = a.retryOnConsistencyError(ctx, func(ctx context.Context) error { + a.logger.DebugContext(ctx, "Updating primary region", "kms_arn", currPrimaryKey.arn, "primary", a.primaryRegion) + _, err := client.UpdatePrimaryRegion(ctx, &kms.UpdatePrimaryRegionInput{ + KeyId: aws.String(currPrimaryKey.id), + PrimaryRegion: aws.String(a.primaryRegion), + }) + if err != nil { + return trace.Wrap(err) + } + return nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + + for _, key := range existingReplicas { + if key.region == a.primaryRegion { + return key.marshal(), nil + } + } + return nil, trace.Errorf("failed to find updated primary key region=%s key_id=%s", a.primaryRegion, key.id) +} + +func (a *awsKMSKeystore) waitForKeyEnabled(ctx context.Context, client mrkClient, key awsKMSKeyID) error { + err := a.retryOnConsistencyError(ctx, func(ctx context.Context) error { + a.logger.DebugContext(ctx, "Waiting for key to be enabled", "key_arn", key.arn) + out, err := client.DescribeKey(ctx, &kms.DescribeKeyInput{ + KeyId: aws.String(key.id), + }) + if err != nil { + a.logger.DebugContext(ctx, "Failed to get key state", "key_arn", key.arn, "err", err) + return trace.Wrap(err, "failed to get key state") + } + // Return a KMSInvalidStateException so this can be retired by + // retryOnConsistencyError. + if out.KeyMetadata.KeyState != kmstypes.KeyStateEnabled { + return &kmstypes.KMSInvalidStateException{ + Message: aws.String("key is not enabled state=" + string(out.KeyMetadata.KeyState)), + } + } + return nil + }) + return trace.Wrap(err) +} + type awsKMSKeyID struct { - arn, account, region string + id, arn, account, region string } func (a awsKMSKeyID) marshal() []byte { return []byte(awskmsPrefix + a.arn) } -func parseAWSKMSKeyID(raw []byte) (awsKMSKeyID, error) { - if keyType(raw) != types.PrivateKeyType_AWS_KMS { - return awsKMSKeyID{}, trace.BadParameter("unable to parse invalid AWS KMS key") - } - keyARN := strings.TrimPrefix(string(raw), awskmsPrefix) +// isMRK checks if a key is a multi-region key. +func (a awsKMSKeyID) isMRK() bool { + return strings.HasPrefix(a.id, "mrk-") +} + +func keyIDFromArn(keyARN string) (awsKMSKeyID, error) { parsedARN, err := arn.Parse(keyARN) if err != nil { return awsKMSKeyID{}, trace.Wrap(err, "unable parse ARN of AWS KMS key") } + id := strings.TrimPrefix(parsedARN.Resource, "key/") return awsKMSKeyID{ + id: id, arn: keyARN, account: parsedARN.AccountID, region: parsedARN.Region, }, nil } +func parseAWSKMSKeyID(raw []byte) (awsKMSKeyID, error) { + if keyType(raw) != types.PrivateKeyType_AWS_KMS { + return awsKMSKeyID{}, trace.BadParameter("unable to parse invalid AWS KMS key") + } + keyARN := strings.TrimPrefix(string(raw), awskmsPrefix) + key, err := keyIDFromArn(keyARN) + return key, trace.Wrap(err) +} + type kmsClient interface { CreateKey(context.Context, *kms.CreateKeyInput, ...func(*kms.Options)) (*kms.CreateKeyOutput, error) GetPublicKey(context.Context, *kms.GetPublicKeyInput, ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) @@ -531,6 +725,13 @@ type kmsClient interface { Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error) } +// mrkClient is a client for managing multi-region keys. +type mrkClient interface { + ReplicateKey(context.Context, *kms.ReplicateKeyInput, ...func(*kms.Options)) (*kms.ReplicateKeyOutput, error) + UpdatePrimaryRegion(context.Context, *kms.UpdatePrimaryRegionInput, ...func(*kms.Options)) (*kms.UpdatePrimaryRegionOutput, error) + DescribeKey(context.Context, *kms.DescribeKeyInput, ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) +} + type stsClient interface { GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) } diff --git a/lib/auth/keystore/aws_kms_test.go b/lib/auth/keystore/aws_kms_test.go index 12fb7e89335cd..f5e02c9ec9db2 100644 --- a/lib/auth/keystore/aws_kms_test.go +++ b/lib/auth/keystore/aws_kms_test.go @@ -24,6 +24,7 @@ import ( "fmt" "slices" "strconv" + "strings" "testing" "time" @@ -147,7 +148,7 @@ func TestAWSKMS_DeleteUnusedKeys(t *testing.T) { err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) require.NoError(t, err) for _, key := range fakeKMS.keys { - if key.arn == otherClusterKeyARN { + if key.arn.String() == otherClusterKeyARN { assert.Equal(t, kmstypes.KeyStateEnabled, key.state) } else { assert.Equal(t, kmstypes.KeyStatePendingDeletion, key.state) @@ -262,6 +263,7 @@ func TestAWSKeyCreationParameters(t *testing.T) { HostUUID: "uuid", AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, awsKMSClient: fakeKMS, + mrkClient: fakeKMS, awsSTSClient: &fakeAWSSTSClient{ account: "123456789012", }, @@ -295,7 +297,7 @@ func TestAWSKeyCreationParameters(t *testing.T) { AWSKMS: &servicecfg.AWSKMSConfig{ AWSAccount: "123456789012", AWSRegion: "us-west-2", - MultiRegion: struct{ Enabled bool }{ + MultiRegion: servicecfg.MultiRegionKeyStore{ Enabled: tc.multiRegion, }, Tags: tc.tags, @@ -351,11 +353,23 @@ func newFakeAWSKMSService(t *testing.T, clock clockwork.Clock, account string, r } type fakeAWSKMSKey struct { - arn string + arn arn.ARN privKeyPEM []byte tags []kmstypes.Tag creationDate time.Time state kmstypes.KeyState + region string + replicas []string +} + +func (f fakeAWSKMSKey) replicaArn(region string) string { + arn := f.arn + arn.Region = region + return arn.String() +} + +func (f fakeAWSKMSKey) hasReplica(region string) bool { + return region == f.region || slices.Contains(f.replicas, region) } func (f *fakeAWSKMSService) CreateKey(_ context.Context, input *kms.CreateKeyInput, _ ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { @@ -392,10 +406,11 @@ func (f *fakeAWSKMSService) CreateKey(_ context.Context, input *kms.CreateKeyInp return nil, trace.BadParameter("unsupported KeySpec %v", input.KeySpec) } f.keys = append(f.keys, &fakeAWSKMSKey{ - arn: a.String(), + arn: a, privKeyPEM: privKeyPEM, tags: input.Tags, creationDate: f.clock.Now(), + region: f.region, state: state, }) return &kms.CreateKeyOutput{ @@ -478,8 +493,11 @@ func (f *fakeAWSKMSService) ListKeys(_ context.Context, input *kms.ListKeysInput } } for ; i < len(f.keys) && len(output.Keys) < pageLimit; i++ { + if !f.keys[i].hasReplica(f.region) { + continue + } output.Keys = append(output.Keys, kmstypes.KeyListEntry{ - KeyArn: aws.String(f.keys[i].arn), + KeyArn: aws.String(f.keys[i].arn.String()), }) } if i < len(f.keys) { @@ -504,17 +522,36 @@ func (f *fakeAWSKMSService) DescribeKey(_ context.Context, input *kms.DescribeKe if err != nil { return nil, trace.Wrap(err) } - return &kms.DescribeKeyOutput{ + out := &kms.DescribeKeyOutput{ KeyMetadata: &kmstypes.KeyMetadata{ + KeyId: aws.String(key.arn.Resource), + Arn: aws.String(key.replicaArn(f.region)), CreationDate: aws.Time(key.creationDate), KeyState: key.state, }, - }, nil + } + if strings.HasPrefix(key.arn.Resource, "mrk-") { + out.KeyMetadata.MultiRegionConfiguration = &kmstypes.MultiRegionConfiguration{ + PrimaryKey: &kmstypes.MultiRegionKey{ + Arn: aws.String(key.arn.String()), + Region: &key.arn.Region, + }, + } + var replicas []kmstypes.MultiRegionKey + for _, replica := range key.replicas { + replicas = append(replicas, kmstypes.MultiRegionKey{ + Arn: aws.String(key.replicaArn(replica)), + Region: aws.String(replica), + }) + } + out.KeyMetadata.MultiRegionConfiguration.ReplicaKeys = replicas + } + return out, nil } func (f *fakeAWSKMSService) findKey(arn string) (*fakeAWSKMSKey, error) { - i := slices.IndexFunc(f.keys, func(k *fakeAWSKMSKey) bool { return k.arn == arn }) - if i < 0 { + i := slices.IndexFunc(f.keys, func(k *fakeAWSKMSKey) bool { return k.arn.String() == arn || k.arn.Resource == arn }) + if i < 0 || !f.keys[i].hasReplica(f.region) { return nil, &kmstypes.NotFoundException{ Message: aws.String(fmt.Sprintf("key %q not found", arn)), } @@ -532,6 +569,51 @@ func (f *fakeAWSKMSService) findKey(arn string) (*fakeAWSKMSKey, error) { return key, nil } +func (f *fakeAWSKMSService) ReplicateKey(ctx context.Context, in *kms.ReplicateKeyInput, _ ...func(*kms.Options)) (*kms.ReplicateKeyOutput, error) { + key, err := f.findKey(*in.KeyId) + if err != nil { + return nil, trace.Wrap(err) + } + if key.region != f.region { + return nil, &kmstypes.InvalidKeyUsageException{ + Message: aws.String("must use primary key for key replication"), + } + } + if key.hasReplica(*in.ReplicaRegion) { + return nil, &kmstypes.AlreadyExistsException{ + Message: aws.String(fmt.Sprintf("replicas %s already exists", *in.ReplicaRegion)), + } + } + key.replicas = append(key.replicas, *in.ReplicaRegion) + return &kms.ReplicateKeyOutput{ + ReplicaKeyMetadata: &kmstypes.KeyMetadata{ + Arn: aws.String(key.replicaArn(*in.ReplicaRegion)), + }, + }, nil +} + +func (f *fakeAWSKMSService) UpdatePrimaryRegion(ctx context.Context, in *kms.UpdatePrimaryRegionInput, _ ...func(*kms.Options)) (*kms.UpdatePrimaryRegionOutput, error) { + key, err := f.findKey(*in.KeyId) + if err != nil { + return nil, trace.Wrap(err) + } + if key.region != f.region { + return nil, &kmstypes.InvalidKeyUsageException{ + Message: aws.String("must use primary key for updating primary region"), + } + } + i := slices.Index(key.replicas, *in.PrimaryRegion) + if i == -1 { + return nil, &kmstypes.InvalidKeyUsageException{ + Message: aws.String("replica does not exist"), + } + } + key.replicas[i] = key.region + key.region = *in.PrimaryRegion + key.arn.Region = *in.PrimaryRegion + return &kms.UpdatePrimaryRegionOutput{}, nil +} + type fakeAWSSTSClient struct { account, arn, userID string } @@ -543,3 +625,177 @@ func (f *fakeAWSSTSClient) GetCallerIdentity(_ context.Context, _ *sts.GetCaller UserId: aws.String(f.userID), }, nil } + +func TestMultiRegionKeyReplication(t *testing.T) { + testAccount := "123456789" + testPrimary := "us-west-2" + testSecondary := "us-east-1" + testReplicas := []string{testSecondary, "us-east-2"} + + tests := []struct { + name string + config servicecfg.AWSKMSConfig + existingPrimary string + existingReplicas []string + expectedReplicas []string + expectedPrimary string + }{ + { + name: "backwards compatibility when no primary/replicas are configured", + config: servicecfg.AWSKMSConfig{ + AWSAccount: testAccount, + AWSRegion: testPrimary, + MultiRegion: servicecfg.MultiRegionKeyStore{ + Enabled: true, + }, + }, + existingReplicas: nil, + expectedReplicas: []string{}, + expectedPrimary: testPrimary, + }, + { + name: "replicas are created when specified from the primary region", + config: servicecfg.AWSKMSConfig{ + AWSAccount: testAccount, + AWSRegion: testPrimary, + MultiRegion: servicecfg.MultiRegionKeyStore{ + Enabled: true, + PrimaryRegion: testPrimary, + ReplicaRegions: testReplicas, + }, + }, + existingReplicas: nil, + expectedReplicas: testReplicas, + expectedPrimary: testPrimary, + }, + { + name: "replicas are not created from outside primary region", + config: servicecfg.AWSKMSConfig{ + AWSAccount: testAccount, + AWSRegion: testSecondary, + MultiRegion: servicecfg.MultiRegionKeyStore{ + Enabled: true, + PrimaryRegion: testPrimary, + ReplicaRegions: testReplicas, + }, + }, + existingReplicas: []string{testSecondary}, + expectedReplicas: []string{testSecondary}, + expectedPrimary: testPrimary, + }, + { + name: "primary region is updated from the existing primary region", + config: servicecfg.AWSKMSConfig{ + AWSAccount: testAccount, + AWSRegion: testPrimary, + MultiRegion: servicecfg.MultiRegionKeyStore{ + Enabled: true, + PrimaryRegion: testSecondary, + ReplicaRegions: []string{testPrimary}, + }, + }, + existingPrimary: testPrimary, + existingReplicas: []string{testSecondary}, + expectedReplicas: []string{testPrimary}, + expectedPrimary: testSecondary, + }, + { + name: "primary region is not updated from a non-primary region", + config: servicecfg.AWSKMSConfig{ + AWSAccount: testAccount, + AWSRegion: testSecondary, + MultiRegion: servicecfg.MultiRegionKeyStore{ + Enabled: true, + PrimaryRegion: testSecondary, + ReplicaRegions: testReplicas, + }, + }, + existingPrimary: testPrimary, + existingReplicas: testReplicas, + expectedReplicas: testReplicas, + expectedPrimary: testPrimary, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + fakeKMS := newFakeAWSKMSService(t, clock, testAccount, tc.config.AWSRegion, 1) + cluster, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ClusterName: "test-cluster"}) + require.NoError(t, err) + opts := &Options{ + ClusterName: cluster, + HostUUID: "uuid", + AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, + awsKMSClient: fakeKMS, + mrkClient: fakeKMS, + awsSTSClient: &fakeAWSSTSClient{ + account: testAccount, + }, + clockworkOverride: clock, + } + + existingPrimary := tc.existingPrimary + if existingPrimary == "" { + existingPrimary = testPrimary + } + fakeKMS.region = existingPrimary + primary, err := NewManager(ctx, &servicecfg.KeystoreConfig{ + AWSKMS: &servicecfg.AWSKMSConfig{ + AWSAccount: tc.config.AWSAccount, + AWSRegion: testPrimary, + MultiRegion: servicecfg.MultiRegionKeyStore{ + Enabled: true, + PrimaryRegion: existingPrimary, + ReplicaRegions: tc.existingReplicas, + }, + }, + }, opts) + require.NoError(t, err) + + kp, err := primary.NewTLSKeyPair(ctx, cluster.GetName(), cryptosuites.HostCATLS) + require.NoError(t, err, trace.DebugReport(err)) + key, err := parseAWSKMSKeyID(kp.Key) + require.NoError(t, err) + require.Equal(t, key.region, existingPrimary) + require.Contains(t, key.arn, "mrk-") + require.ElementsMatch(t, tc.existingReplicas, fakeKMS.keys[0].replicas) + + fakeKMS.region = tc.config.AWSRegion + mgr, err := NewManager(ctx, &servicecfg.KeystoreConfig{ + AWSKMS: &tc.config, + }, opts) + require.NoError(t, err) + + id, err := mgr.ApplyMultiRegionConfig(ctx, kp.Key) + require.NoError(t, err) + + key, err = parseAWSKMSKeyID(id) + require.NoError(t, err) + require.Equal(t, tc.expectedPrimary, key.region) + + out, err := fakeKMS.DescribeKey(ctx, &kms.DescribeKeyInput{ + KeyId: &key.id, + }) + require.NoError(t, err) + + mrc := out.KeyMetadata.MultiRegionConfiguration + if tc.expectedPrimary != "" { + require.Equal(t, + tc.expectedPrimary, + *mrc.PrimaryKey.Region, + ) + } + for _, replica := range tc.expectedReplicas { + require.True(t, slices.ContainsFunc(mrc.ReplicaKeys, func(key kmstypes.MultiRegionKey) bool { + return *key.Region == replica + }), "expected %s found in replicas %v", replica, mrc.ReplicaKeys) + } + for _, replica := range mrc.ReplicaKeys { + require.Contains(t, tc.expectedReplicas, *replica.Region) + } + }) + } + +} diff --git a/lib/auth/keystore/keystore_test.go b/lib/auth/keystore/keystore_test.go index 13e05a5e5070b..7094af2b89b89 100644 --- a/lib/auth/keystore/keystore_test.go +++ b/lib/auth/keystore/keystore_test.go @@ -541,13 +541,15 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { ClusterName: clusterName, }) require.NoError(t, err) + fakeKMS := newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", 100) baseOpts := Options{ ClusterName: clusterName, HostUUID: hostUUID, Logger: logger, AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, - awsKMSClient: newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", 100), + awsKMSClient: fakeKMS, + mrkClient: fakeKMS, awsSTSClient: &fakeAWSSTSClient{ account: "123456789012", }, @@ -683,7 +685,7 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { AWSKMS: &servicecfg.AWSKMSConfig{ AWSAccount: "123456789012", AWSRegion: "us-west-2", - MultiRegion: struct{ Enabled bool }{ + MultiRegion: servicecfg.MultiRegionKeyStore{ Enabled: multiRegion, }, }, diff --git a/lib/auth/keystore/manager.go b/lib/auth/keystore/manager.go index 68073ae7f297f..4eb61c6b19b68 100644 --- a/lib/auth/keystore/manager.go +++ b/lib/auth/keystore/manager.go @@ -157,9 +157,11 @@ type Options struct { // FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested. FIPS bool - awsKMSClient kmsClient - awsSTSClient stsClient - kmsClient *kms.KeyManagementClient + awsKMSClient kmsClient + mrkClient mrkClient + awsSTSClient stsClient + kmsClient *kms.KeyManagementClient + clockworkOverride clockwork.Clock // GCPKMS uses a special fake clock that seemed more testable at the time. faketimeOverride faketime.Clock @@ -674,6 +676,21 @@ func (m *Manager) DeleteUnusedKeys(ctx context.Context, activeKeys [][]byte) err return trace.Wrap(m.backendForNewKeys.deleteUnusedKeys(ctx, activeKeys)) } +// ApplyMultiRegionConfig configures the given keyID with the current multi-region +// parameters and returns the updated keyID. This is currently only implemented +// for AWS KMS. +func (m *Manager) ApplyMultiRegionConfig(ctx context.Context, keyID []byte) ([]byte, error) { + backend, ok := m.backendForNewKeys.(*awsKMSKeystore) + if !ok { + return keyID, nil + } + keyID, err := backend.applyMultiRegionConfig(ctx, keyID) + if err != nil { + return nil, trace.Wrap(err) + } + return keyID, nil +} + // UsingHSMOrKMS returns true if the keystore is configured to use an HSM or KMS // when generating new keys. func (m *Manager) UsingHSMOrKMS() bool { diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 9f451ff488de7..eb6726f59fcfc 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -1191,9 +1191,10 @@ func applyAWSKMSConfig(kmsConfig *AWSKMS, cfg *servicecfg.Config) error { cfg.Auth.KeyStore.AWSKMS = &servicecfg.AWSKMSConfig{ AWSAccount: kmsConfig.Account, AWSRegion: kmsConfig.Region, - MultiRegion: kmsConfig.MultiRegion, Tags: kmsConfig.Tags, + MultiRegion: kmsConfig.MultiRegion, } + return nil } diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index 24ab642238871..76e754f4a0b0c 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -921,10 +921,7 @@ type AWSKMS struct { // Region is the AWS region to use. Region string `yaml:"region"` // MultiRegion contains configuration for multi-region AWS KMS. - MultiRegion struct { - // Enabled configures new keys to be multi-region. - Enabled bool - } `yaml:"multi_region,omitempty"` + MultiRegion servicecfg.MultiRegionKeyStore `yaml:"multi_region,omitempty"` // Tags are key/value pairs used as AWS resource tags. The 'TeleportCluster' // tag is added automatically if not specified in the set of tags. Changing tags // after Teleport has already created KMS keys may require manually updating diff --git a/lib/service/servicecfg/auth.go b/lib/service/servicecfg/auth.go index b3719d684f015..6288b1bb1e600 100644 --- a/lib/service/servicecfg/auth.go +++ b/lib/service/servicecfg/auth.go @@ -302,13 +302,10 @@ func (cfg *GCPKMSConfig) CheckAndSetDefaults() error { type AWSKMSConfig struct { // AWSAccount is the AWS account ID where the keys will reside. AWSAccount string - // AWSRegion is the AWS region where the keys will reside. + // AWSRegion is the region used for KMS key operations. AWSRegion string // MultiRegion contains configuration for multi-region AWS KMS. - MultiRegion struct { - // Enabled configures new keys to be multi-region. - Enabled bool - } + MultiRegion MultiRegionKeyStore // Tags are key/value pairs used as AWS resource tags. The 'TeleportCluster' // tag is added automatically if not specified in the set of tags. Changing tags // after Teleport has already created KMS keys may require manually updating @@ -327,3 +324,13 @@ func (c *AWSKMSConfig) CheckAndSetDefaults() error { } return nil } + +// MultiRegionKeyStore contains configuration for a multi-region keystore +type MultiRegionKeyStore struct { + // Enabled configures new keys to be multi-region. + Enabled bool `yaml:"enabled"` + // PrimaryRegion is the region the primary key is located. + PrimaryRegion string `yaml:"primary_region"` + // ReplicaRegions is a list of regions keys will be replicated to. + ReplicaRegions []string `yaml:"replica_regions"` +} From e8c97f21bd20c71c8bfa5f06868ac9f5cb5f5542 Mon Sep 17 00:00:00 2001 From: David Boslee Date: Fri, 30 May 2025 07:35:26 -0400 Subject: [PATCH 2/2] keystore: retry describe key when applying multi-region kms config (#55274) --- lib/auth/keystore/aws_kms.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lib/auth/keystore/aws_kms.go b/lib/auth/keystore/aws_kms.go index b5be361a1d0e0..d82962265c4a0 100644 --- a/lib/auth/keystore/aws_kms.go +++ b/lib/auth/keystore/aws_kms.go @@ -566,8 +566,16 @@ func (a *awsKMSKeystore) applyMRKConfig(ctx context.Context, key awsKMSKeyID) ([ } client := a.mrk - describeKeyOut, err := client.DescribeKey(ctx, &kms.DescribeKeyInput{ - KeyId: aws.String(key.id), + var describeKeyOut *kms.DescribeKeyOutput + err := a.retryOnConsistencyError(ctx, func(ctx context.Context) error { + var err error + describeKeyOut, err = client.DescribeKey(ctx, &kms.DescribeKeyInput{ + KeyId: aws.String(key.id), + }) + if err != nil { + return trace.Wrap(err) + } + return nil }) if err != nil { return nil, trace.Wrap(err)