diff --git a/lib/auth/keystore/aws_kms.go b/lib/auth/keystore/aws_kms.go index 8d07f3fa785c8..8259f6393d01c 100644 --- a/lib/auth/keystore/aws_kms.go +++ b/lib/auth/keystore/aws_kms.go @@ -53,6 +53,7 @@ import ( const ( awskmsPrefix = "awskms:" clusterTagKey = "TeleportCluster" + awsOAEPHash = crypto.SHA256 pendingKeyBaseRetryInterval = time.Second / 2 pendingKeyMaxRetryInterval = 4 * time.Second @@ -163,12 +164,19 @@ func (a *awsKMSKeystore) keyTypeDescription() string { return fmt.Sprintf("AWS KMS keys in account %s and region %s", a.awsAccount, a.awsRegion) } -// generateSigner creates a new private key and returns its identifier and a crypto.Signer. The returned -// identifier can be passed to getSigner later to get an equivalent crypto.Signer. -func (a *awsKMSKeystore) generateSigner(ctx context.Context, algorithm cryptosuites.Algorithm) ([]byte, crypto.Signer, error) { +func (u keyUsage) toAWS() kmstypes.KeyUsageType { + switch u { + case keyUsageDecrypt: + return kmstypes.KeyUsageTypeEncryptDecrypt + default: + return kmstypes.KeyUsageTypeSignVerify + } +} + +func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites.Algorithm, usage keyUsage) (awsKMSKeyID, error) { alg, err := awsAlgorithm(algorithm) if err != nil { - return nil, nil, trace.Wrap(err) + return awsKMSKeyID{}, trace.Wrap(err) } a.logger.InfoContext(ctx, "Creating new AWS KMS keypair.", @@ -186,18 +194,29 @@ func (a *awsKMSKeystore) generateSigner(ctx context.Context, algorithm cryptosui output, err := a.kms.CreateKey(ctx, &kms.CreateKeyInput{ Description: aws.String("Teleport CA key"), KeySpec: alg, - KeyUsage: kmstypes.KeyUsageTypeSignVerify, + KeyUsage: usage.toAWS(), Tags: tags, MultiRegion: aws.Bool(a.multiRegionEnabled), }) if err != nil { - return nil, nil, trace.Wrap(err) + return awsKMSKeyID{}, trace.Wrap(err) } if output.KeyMetadata == nil { - return nil, nil, trace.Errorf("KeyMetadata of generated key is nil") + return awsKMSKeyID{}, trace.Errorf("KeyMetadata of generated key is nil") } keyARN := aws.ToString(output.KeyMetadata.Arn) key, err := keyIDFromArn(keyARN) + if err != nil { + return awsKMSKeyID{}, trace.Wrap(err) + } + + return key, nil +} + +// generateSigner creates a new private key and returns its identifier and a crypto.Signer. The returned +// identifier can be passed to getSigner later to get an equivalent crypto.Signer. +func (a *awsKMSKeystore) generateSigner(ctx context.Context, algorithm cryptosuites.Algorithm) ([]byte, crypto.Signer, error) { + key, err := a.generateKey(ctx, algorithm, keyUsageSign) if err != nil { return nil, nil, trace.Wrap(err) } @@ -205,15 +224,29 @@ func (a *awsKMSKeystore) generateSigner(ctx context.Context, algorithm cryptosui if err != nil { return nil, nil, trace.Wrap(err) } - signer, err := a.newSigner(ctx, key) + signer, err := a.newKMSKey(ctx, key) if err != nil { return nil, nil, trace.Wrap(err) } return keyID, signer, nil } -func (a *awsKMSKeystore) generateDecrypter(ctx context.Context, alg cryptosuites.Algorithm) (keyID []byte, decrypter crypto.Decrypter, hash crypto.Hash, err error) { - return nil, nil, crypto.SHA256, trace.NotImplemented("decryption not yet supported for AWS KMS key store") +// generateDecrypter creates a new private key and returns its identifier and a crypto.Decrypter. The returned +// identifier can be passed to getDecrypter later to get an equivalent crypto.Decrypter. +func (a *awsKMSKeystore) generateDecrypter(ctx context.Context, algorithm cryptosuites.Algorithm) ([]byte, crypto.Decrypter, crypto.Hash, error) { + key, err := a.generateKey(ctx, algorithm, keyUsageDecrypt) + if err != nil { + return nil, nil, awsOAEPHash, trace.Wrap(err) + } + keyID, err := a.applyMRKConfig(ctx, key) + if err != nil { + return nil, nil, awsOAEPHash, trace.Wrap(err) + } + decrypter, err := a.newKMSKey(ctx, key) + if err != nil { + return nil, nil, awsOAEPHash, trace.Wrap(err) + } + return keyID, decrypter, awsOAEPHash, nil } func awsAlgorithm(alg cryptosuites.Algorithm) (kmstypes.KeySpec, error) { @@ -232,21 +265,25 @@ func (a *awsKMSKeystore) getSigner(ctx context.Context, rawKey []byte, publicKey if err != nil { return nil, trace.Wrap(err) } - return a.newSignerWithPublicKey(ctx, key, publicKey) + return a.newKMSKeyWithPublicKey(ctx, key, publicKey) } -// getDecrypter returns a crypto.Signer for the given key identifier, if it is found. +// getDecrypter returns a crypto.Decrypter for the given key identifier, if it is found. func (a *awsKMSKeystore) getDecrypter(ctx context.Context, rawKey []byte, publicKey crypto.PublicKey, hash crypto.Hash) (crypto.Decrypter, error) { - return nil, trace.NotImplemented("decryption not yet supported for AWS KMS key store") + key, err := parseAWSKMSKeyID(rawKey) + if err != nil { + return nil, trace.Wrap(err) + } + return a.newKMSKeyWithPublicKey(ctx, key, publicKey) } -type awsKMSSigner struct { +type awsKMSKey struct { key awsKMSKeyID pub crypto.PublicKey kms kmsClient } -func (a *awsKMSKeystore) newSigner(ctx context.Context, key awsKMSKeyID) (*awsKMSSigner, error) { +func (a *awsKMSKeystore) newKMSKey(ctx context.Context, key awsKMSKeyID) (*awsKMSKey, error) { var pubkeyDER []byte err := a.retryOnConsistencyError(ctx, func(ctx context.Context) error { a.logger.DebugContext(ctx, "Fetching public key", "key_arn", key.arn) @@ -268,7 +305,7 @@ func (a *awsKMSKeystore) newSigner(ctx context.Context, key awsKMSKeyID) (*awsKM if err != nil { return nil, trace.Wrap(err, "unexpected error parsing public key der") } - return a.newSignerWithPublicKey(ctx, key, pub) + return a.newKMSKeyWithPublicKey(ctx, key, pub) } // retryOnConsistencyError handles retrying KMS key operations that may fail @@ -313,8 +350,8 @@ func (a *awsKMSKeystore) retryOnConsistencyError(ctx context.Context, fn func(ct } } -func (a *awsKMSKeystore) newSignerWithPublicKey(_ context.Context, key awsKMSKeyID, publicKey crypto.PublicKey) (*awsKMSSigner, error) { - return &awsKMSSigner{ +func (a *awsKMSKeystore) newKMSKeyWithPublicKey(_ context.Context, key awsKMSKeyID, publicKey crypto.PublicKey) (*awsKMSKey, error) { + return &awsKMSKey{ key: key, pub: publicKey, kms: a.kms, @@ -322,12 +359,12 @@ func (a *awsKMSKeystore) newSignerWithPublicKey(_ context.Context, key awsKMSKey } // Public returns the public key for the signer. -func (a *awsKMSSigner) Public() crypto.PublicKey { +func (a *awsKMSKey) Public() crypto.PublicKey { return a.pub } // Sign signs the message digest. -func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { +func (a *awsKMSKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { var signingAlg kmstypes.SigningAlgorithmSpec switch opts.HashFunc() { case crypto.SHA256: @@ -363,6 +400,27 @@ func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpt return output.Signature, nil } +// Decrypt decrypts data encrypted with the public key +func (a *awsKMSKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { + var encAlg kmstypes.EncryptionAlgorithmSpec + switch a.pub.(type) { + case *rsa.PublicKey: + encAlg = kmstypes.EncryptionAlgorithmSpecRsaesOaepSha256 + default: + return nil, trace.BadParameter("unsupported key algorithm for AWS KMS decryption") + } + + output, err := a.kms.Decrypt(context.TODO(), &kms.DecryptInput{ + KeyId: aws.String(a.key.id), + CiphertextBlob: ciphertext, + EncryptionAlgorithm: encAlg, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return output.Plaintext, nil +} + // deleteKey deletes the given key from the KeyStore. func (a *awsKMSKeystore) deleteKey(ctx context.Context, rawKey []byte) error { keyID, err := parseAWSKMSKeyID(rawKey) @@ -748,6 +806,7 @@ type kmsClient interface { DescribeKey(context.Context, *kms.DescribeKeyInput, ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) ListResourceTags(context.Context, *kms.ListResourceTagsInput, ...func(*kms.Options)) (*kms.ListResourceTagsOutput, error) Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error) + Decrypt(context.Context, *kms.DecryptInput, ...func(*kms.Options)) (*kms.DecryptOutput, error) } // mrkClient is a client for managing multi-region keys. diff --git a/lib/auth/keystore/keystore_test.go b/lib/auth/keystore/keystore_test.go index ea6e441624250..3f447744f3717 100644 --- a/lib/auth/keystore/keystore_test.go +++ b/lib/auth/keystore/keystore_test.go @@ -27,7 +27,6 @@ import ( "crypto/rsa" "crypto/sha256" "errors" - "strings" "testing" "time" @@ -301,11 +300,8 @@ func TestManager(t *testing.T) { require.Equal(t, backendDesc.expectedKeyType, jwtKeyPair.PrivateKeyType) encKeyPair, err := manager.NewEncryptionKeyPair(ctx, cryptosuites.RecordingKeyWrapping) - // TODO (eriktate): remove once decryption with AWS is added - if !strings.Contains(backendDesc.name, "aws_kms") { - require.NoError(t, err) - require.Equal(t, backendDesc.expectedKeyType, encKeyPair.PrivateKeyType) - } + require.NoError(t, err) + require.Equal(t, backendDesc.expectedKeyType, encKeyPair.PrivateKeyType) // Test a CA with multiple active keypairs. Each element of ActiveKeys // includes a keypair generated above and a PKCS11 keypair with a @@ -347,23 +343,20 @@ func TestManager(t *testing.T) { require.NoError(t, err) require.Equal(t, jwtKeyPair.PublicKey, pubkeyPem) - // TODO (eriktate): remove once decryption with AWS is added - if !strings.Contains(backendDesc.name, "aws_kms") { - decrypter, err := manager.GetDecrypter(ctx, encKeyPair) + decrypter, err := manager.GetDecrypter(ctx, encKeyPair) - require.NoError(t, err) - require.NotNil(t, decrypter) + require.NoError(t, err) + require.NotNil(t, decrypter) - // Try encrypting and decrypting some data - msg := []byte("teleport") - require.NoError(t, err) - ciphertext, err := encKeyPair.EncryptOAEP(msg) - require.NoError(t, err) + // Try encrypting and decrypting some data + msg := []byte("teleport") + require.NoError(t, err) + ciphertext, err := encKeyPair.EncryptOAEP(msg) + require.NoError(t, err) - plaintext, err := decrypter.Decrypt(rand.Reader, ciphertext, nil) - require.NoError(t, err) - require.Equal(t, msg, plaintext) - } + plaintext, err := decrypter.Decrypt(rand.Reader, ciphertext, nil) + require.NoError(t, err) + require.Equal(t, msg, plaintext) // Try signing an SSH cert. sshCert := ssh.Certificate{