Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 79 additions & 20 deletions lib/auth/keystore/aws_kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
const (
awskmsPrefix = "awskms:"
clusterTagKey = "TeleportCluster"
awsOAEPHash = crypto.SHA256

pendingKeyBaseRetryInterval = time.Second / 2
pendingKeyMaxRetryInterval = 4 * time.Second
Expand Down Expand Up @@ -163,12 +164,19 @@ func (a *awsKMSKeystore) keyTypeDescription() string {
return fmt.Sprintf("AWS KMS keys in account %s and region %s", a.awsAccount, a.awsRegion)
}

// generateSigner creates a new private key and returns its identifier and a crypto.Signer. The returned
// identifier can be passed to getSigner later to get an equivalent crypto.Signer.
func (a *awsKMSKeystore) generateSigner(ctx context.Context, algorithm cryptosuites.Algorithm) ([]byte, crypto.Signer, error) {
func (u keyUsage) toAWS() kmstypes.KeyUsageType {
switch u {
case keyUsageDecrypt:
return kmstypes.KeyUsageTypeEncryptDecrypt
default:
return kmstypes.KeyUsageTypeSignVerify
}
}

func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites.Algorithm, usage keyUsage) (awsKMSKeyID, error) {
alg, err := awsAlgorithm(algorithm)
if err != nil {
return nil, nil, trace.Wrap(err)
return awsKMSKeyID{}, trace.Wrap(err)
}

a.logger.InfoContext(ctx, "Creating new AWS KMS keypair.",
Expand All @@ -186,34 +194,59 @@ func (a *awsKMSKeystore) generateSigner(ctx context.Context, algorithm cryptosui
output, err := a.kms.CreateKey(ctx, &kms.CreateKeyInput{
Description: aws.String("Teleport CA key"),
KeySpec: alg,
KeyUsage: kmstypes.KeyUsageTypeSignVerify,
KeyUsage: usage.toAWS(),
Tags: tags,
MultiRegion: aws.Bool(a.multiRegionEnabled),
})
if err != nil {
return nil, nil, trace.Wrap(err)
return awsKMSKeyID{}, trace.Wrap(err)
}
if output.KeyMetadata == nil {
return nil, nil, trace.Errorf("KeyMetadata of generated key is nil")
return awsKMSKeyID{}, trace.Errorf("KeyMetadata of generated key is nil")
}
keyARN := aws.ToString(output.KeyMetadata.Arn)
key, err := keyIDFromArn(keyARN)
if err != nil {
return awsKMSKeyID{}, trace.Wrap(err)
}

return key, nil
}

// generateSigner creates a new private key and returns its identifier and a crypto.Signer. The returned
// identifier can be passed to getSigner later to get an equivalent crypto.Signer.
func (a *awsKMSKeystore) generateSigner(ctx context.Context, algorithm cryptosuites.Algorithm) ([]byte, crypto.Signer, error) {
key, err := a.generateKey(ctx, algorithm, keyUsageSign)
if err != nil {
return nil, nil, trace.Wrap(err)
}
keyID, err := a.applyMRKConfig(ctx, key)
if err != nil {
return nil, nil, trace.Wrap(err)
}
signer, err := a.newSigner(ctx, key)
signer, err := a.newKMSKey(ctx, key)
if err != nil {
return nil, nil, trace.Wrap(err)
}
return keyID, signer, nil
}

func (a *awsKMSKeystore) generateDecrypter(ctx context.Context, alg cryptosuites.Algorithm) (keyID []byte, decrypter crypto.Decrypter, hash crypto.Hash, err error) {
return nil, nil, crypto.SHA256, trace.NotImplemented("decryption not yet supported for AWS KMS key store")
// generateDecrypter creates a new private key and returns its identifier and a crypto.Decrypter. The returned
// identifier can be passed to getDecrypter later to get an equivalent crypto.Decrypter.
func (a *awsKMSKeystore) generateDecrypter(ctx context.Context, algorithm cryptosuites.Algorithm) ([]byte, crypto.Decrypter, crypto.Hash, error) {
key, err := a.generateKey(ctx, algorithm, keyUsageDecrypt)
if err != nil {
return nil, nil, awsOAEPHash, trace.Wrap(err)
}
keyID, err := a.applyMRKConfig(ctx, key)
if err != nil {
return nil, nil, awsOAEPHash, trace.Wrap(err)
}
decrypter, err := a.newKMSKey(ctx, key)
if err != nil {
return nil, nil, awsOAEPHash, trace.Wrap(err)
}
return keyID, decrypter, awsOAEPHash, nil
}

func awsAlgorithm(alg cryptosuites.Algorithm) (kmstypes.KeySpec, error) {
Expand All @@ -232,21 +265,25 @@ func (a *awsKMSKeystore) getSigner(ctx context.Context, rawKey []byte, publicKey
if err != nil {
return nil, trace.Wrap(err)
}
return a.newSignerWithPublicKey(ctx, key, publicKey)
return a.newKMSKeyWithPublicKey(ctx, key, publicKey)
}

// getDecrypter returns a crypto.Signer for the given key identifier, if it is found.
// getDecrypter returns a crypto.Decrypter for the given key identifier, if it is found.
func (a *awsKMSKeystore) getDecrypter(ctx context.Context, rawKey []byte, publicKey crypto.PublicKey, hash crypto.Hash) (crypto.Decrypter, error) {
return nil, trace.NotImplemented("decryption not yet supported for AWS KMS key store")
key, err := parseAWSKMSKeyID(rawKey)
if err != nil {
return nil, trace.Wrap(err)
}
return a.newKMSKeyWithPublicKey(ctx, key, publicKey)
}

type awsKMSSigner struct {
type awsKMSKey struct {
key awsKMSKeyID
pub crypto.PublicKey
kms kmsClient
}

func (a *awsKMSKeystore) newSigner(ctx context.Context, key awsKMSKeyID) (*awsKMSSigner, error) {
func (a *awsKMSKeystore) newKMSKey(ctx context.Context, key awsKMSKeyID) (*awsKMSKey, error) {
var pubkeyDER []byte
err := a.retryOnConsistencyError(ctx, func(ctx context.Context) error {
a.logger.DebugContext(ctx, "Fetching public key", "key_arn", key.arn)
Expand All @@ -268,7 +305,7 @@ func (a *awsKMSKeystore) newSigner(ctx context.Context, key awsKMSKeyID) (*awsKM
if err != nil {
return nil, trace.Wrap(err, "unexpected error parsing public key der")
}
return a.newSignerWithPublicKey(ctx, key, pub)
return a.newKMSKeyWithPublicKey(ctx, key, pub)
}

// retryOnConsistencyError handles retrying KMS key operations that may fail
Expand Down Expand Up @@ -313,21 +350,21 @@ func (a *awsKMSKeystore) retryOnConsistencyError(ctx context.Context, fn func(ct
}
}

func (a *awsKMSKeystore) newSignerWithPublicKey(_ context.Context, key awsKMSKeyID, publicKey crypto.PublicKey) (*awsKMSSigner, error) {
return &awsKMSSigner{
func (a *awsKMSKeystore) newKMSKeyWithPublicKey(_ context.Context, key awsKMSKeyID, publicKey crypto.PublicKey) (*awsKMSKey, error) {
return &awsKMSKey{
key: key,
pub: publicKey,
kms: a.kms,
}, nil
}

// Public returns the public key for the signer.
func (a *awsKMSSigner) Public() crypto.PublicKey {
func (a *awsKMSKey) Public() crypto.PublicKey {
return a.pub
}

// Sign signs the message digest.
func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
func (a *awsKMSKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
var signingAlg kmstypes.SigningAlgorithmSpec
switch opts.HashFunc() {
case crypto.SHA256:
Expand Down Expand Up @@ -363,6 +400,27 @@ func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpt
return output.Signature, nil
}

// Decrypt decrypts data encrypted with the public key
func (a *awsKMSKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
var encAlg kmstypes.EncryptionAlgorithmSpec
switch a.pub.(type) {
case *rsa.PublicKey:
encAlg = kmstypes.EncryptionAlgorithmSpecRsaesOaepSha256
default:
return nil, trace.BadParameter("unsupported key algorithm for AWS KMS decryption")
}

output, err := a.kms.Decrypt(context.TODO(), &kms.DecryptInput{
KeyId: aws.String(a.key.id),
CiphertextBlob: ciphertext,
EncryptionAlgorithm: encAlg,
})
if err != nil {
return nil, trace.Wrap(err)
}
return output.Plaintext, nil
}

// deleteKey deletes the given key from the KeyStore.
func (a *awsKMSKeystore) deleteKey(ctx context.Context, rawKey []byte) error {
keyID, err := parseAWSKMSKeyID(rawKey)
Expand Down Expand Up @@ -748,6 +806,7 @@ type kmsClient interface {
DescribeKey(context.Context, *kms.DescribeKeyInput, ...func(*kms.Options)) (*kms.DescribeKeyOutput, error)
ListResourceTags(context.Context, *kms.ListResourceTagsInput, ...func(*kms.Options)) (*kms.ListResourceTagsOutput, error)
Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error)
Decrypt(context.Context, *kms.DecryptInput, ...func(*kms.Options)) (*kms.DecryptOutput, error)
}

// mrkClient is a client for managing multi-region keys.
Expand Down
33 changes: 13 additions & 20 deletions lib/auth/keystore/keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"crypto/rsa"
"crypto/sha256"
"errors"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -301,11 +300,8 @@ func TestManager(t *testing.T) {
require.Equal(t, backendDesc.expectedKeyType, jwtKeyPair.PrivateKeyType)

encKeyPair, err := manager.NewEncryptionKeyPair(ctx, cryptosuites.RecordingKeyWrapping)
// TODO (eriktate): remove once decryption with AWS is added
if !strings.Contains(backendDesc.name, "aws_kms") {
require.NoError(t, err)
require.Equal(t, backendDesc.expectedKeyType, encKeyPair.PrivateKeyType)
}
require.NoError(t, err)
require.Equal(t, backendDesc.expectedKeyType, encKeyPair.PrivateKeyType)

// Test a CA with multiple active keypairs. Each element of ActiveKeys
// includes a keypair generated above and a PKCS11 keypair with a
Expand Down Expand Up @@ -347,23 +343,20 @@ func TestManager(t *testing.T) {
require.NoError(t, err)
require.Equal(t, jwtKeyPair.PublicKey, pubkeyPem)

// TODO (eriktate): remove once decryption with AWS is added
if !strings.Contains(backendDesc.name, "aws_kms") {
decrypter, err := manager.GetDecrypter(ctx, encKeyPair)
decrypter, err := manager.GetDecrypter(ctx, encKeyPair)

require.NoError(t, err)
require.NotNil(t, decrypter)
require.NoError(t, err)
require.NotNil(t, decrypter)

// Try encrypting and decrypting some data
msg := []byte("teleport")
require.NoError(t, err)
ciphertext, err := encKeyPair.EncryptOAEP(msg)
require.NoError(t, err)
// Try encrypting and decrypting some data
msg := []byte("teleport")
require.NoError(t, err)
ciphertext, err := encKeyPair.EncryptOAEP(msg)
require.NoError(t, err)

plaintext, err := decrypter.Decrypt(rand.Reader, ciphertext, nil)
require.NoError(t, err)
require.Equal(t, msg, plaintext)
}
plaintext, err := decrypter.Decrypt(rand.Reader, ciphertext, nil)
require.NoError(t, err)
require.Equal(t, msg, plaintext)

// Try signing an SSH cert.
sshCert := ssh.Certificate{
Expand Down
Loading