Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
.idea
go.work
go.work.sum
.env
.envrc
108 changes: 59 additions & 49 deletions wrappers/awskms/awskms.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ import (
"os"
"sync/atomic"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/kms"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/go-secure-stdlib/awsutil/v2"
)

// These constants contain the accepted env vars; the Vault one is for backwards compat
const (
EnvAwsKmsWrapperKeyId = "AWSKMS_WRAPPER_KEY_ID"
EnvVaultAwsKmsSealKeyId = "VAULT_AWSKMS_SEAL_KEY_ID"
EnvAwsKmsWrapperKeyId = "AWSKMS_WRAPPER_KEY_ID"
EnvVaultAwsKmsSealKeyId = "VAULT_AWSKMS_SEAL_KEY_ID"
DeprecatedEnvAwsKmsEndpoint = "AWS_KMS_ENDPOINT"
EnvAwsKmsEndpoint = "AWSKMS_ENDPOINT"
)

const (
Expand Down Expand Up @@ -52,11 +52,18 @@ type Wrapper struct {

currentKeyId *atomic.Value

client kmsiface.KMSAPI
client KmsApi

logger hclog.Logger
}

// KmsApi defines the functionality expected to be implemented by the AWS SDK v2 kms package.
type KmsApi interface {
Encrypt(ctx context.Context, input *kms.EncryptInput, opts ...func(*kms.Options)) (*kms.EncryptOutput, error)
Decrypt(ctx context.Context, input *kms.DecryptInput, opts ...func(*kms.Options)) (*kms.DecryptOutput, error)
DescribeKey(ctx context.Context, inpput *kms.DescribeKeyInput, opts ...func(*kms.Options)) (*kms.DescribeKeyOutput, error)
}
Comment thread
johanbrandhorst marked this conversation as resolved.

// Ensure that we are implementing Wrapper
var _ wrapping.Wrapper = (*Wrapper)(nil)

Expand All @@ -77,7 +84,7 @@ func NewWrapper() *Wrapper {
// * Passed in config map
// * Instance metadata role (access key and secret key)
// * Default values
func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) {
func (k *Wrapper) SetConfig(ctx context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, err
Expand All @@ -103,7 +110,7 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin
k.currentKeyId.Store(k.keyId)

// Please see GetRegion for an explanation of the order in which region is parsed.
k.region, err = awsutil.GetRegion(opts.withRegion)
k.region, err = awsutil.GetRegion(ctx, opts.withRegion)
if err != nil {
return nil, err
}
Expand All @@ -119,31 +126,35 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin
k.roleArn = opts.withRoleArn

if !opts.withDisallowEnvVars {
k.endpoint = os.Getenv("AWS_KMS_ENDPOINT")
ep := os.Getenv(EnvAwsKmsEndpoint)
if ep == "" {
ep = os.Getenv(DeprecatedEnvAwsKmsEndpoint)
}
k.endpoint = ep
}
if k.endpoint == "" {
k.endpoint = opts.withEndpoint
}

// Check and set k.client
if k.client == nil {
client, err := k.GetAwsKmsClient()
client, err := k.GetAwsKmsClient(ctx)
if err != nil {
return nil, fmt.Errorf("error initializing AWS KMS wrapping client: %w", err)
}

if !k.keyNotRequired {
// Test the client connection using provided key ID
keyInfo, err := client.DescribeKey(&kms.DescribeKeyInput{
KeyId: aws.String(k.keyId),
keyInfo, err := client.DescribeKey(ctx, &kms.DescribeKeyInput{
KeyId: &k.keyId,
})
if err != nil {
return nil, fmt.Errorf("error fetching AWS KMS wrapping key information: %w", err)
}
if keyInfo == nil || keyInfo.KeyMetadata == nil || keyInfo.KeyMetadata.KeyId == nil {
return nil, errors.New("no key information returned")
}
k.currentKeyId.Store(aws.StringValue(keyInfo.KeyMetadata.KeyId))
k.currentKeyId.Store(*keyInfo.KeyMetadata.KeyId)
Comment thread
johanbrandhorst marked this conversation as resolved.
}

k.client = client
Expand Down Expand Up @@ -174,7 +185,7 @@ func (k *Wrapper) KeyId(_ context.Context) (string, error) {
// Encrypt is used to encrypt the master key using the the AWS CMK.
// This returns the ciphertext, and/or any errors from this
// call. This should be called after the KMS client has been instantiated.
func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) {
func (k *Wrapper) Encrypt(ctx context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) {
if plaintext == nil {
return nil, fmt.Errorf("given plaintext for encryption is nil")
}
Expand All @@ -189,10 +200,10 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O
}

input := &kms.EncryptInput{
KeyId: aws.String(k.keyId),
KeyId: &k.keyId,
Plaintext: env.Key,
}
output, err := k.client.Encrypt(input)
output, err := k.client.Encrypt(ctx, input)
if err != nil {
return nil, fmt.Errorf("error encrypting data: %w", err)
}
Expand All @@ -203,8 +214,8 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O
// used for encryption. This is helpful if you are looking to reencyrpt
// your data when it is not using the latest key id. See these docs relating
// to key rotation https://docs.aws.amazon.com/kms/latest/developerguide/rotate-keys.html
keyId := aws.StringValue(output.KeyId)
k.currentKeyId.Store(keyId)
keyId := output.KeyId
k.currentKeyId.Store(*keyId)

ret := &wrapping.BlobInfo{
Ciphertext: env.Ciphertext,
Expand All @@ -214,7 +225,7 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O
// Even though we do not use the key id during decryption, store it
// to know exactly the specific key used in encryption in case we
// want to rewrap older entries
KeyId: keyId,
KeyId: *keyId,
WrappedKey: output.CiphertextBlob,
},
}
Expand All @@ -223,7 +234,7 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O
}

// Decrypt is used to decrypt the ciphertext. This should be called after Init.
func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapping.Option) ([]byte, error) {
func (k *Wrapper) Decrypt(ctx context.Context, in *wrapping.BlobInfo, opt ...wrapping.Option) ([]byte, error) {
if in == nil {
return nil, fmt.Errorf("given input for decryption is nil")
}
Expand All @@ -242,7 +253,7 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp
CiphertextBlob: in.Ciphertext,
}

output, err := k.client.Decrypt(input)
output, err := k.client.Decrypt(ctx, input)
if err != nil {
return nil, fmt.Errorf("error decrypting data: %w", err)
}
Expand All @@ -254,7 +265,7 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp
input := &kms.DecryptInput{
CiphertextBlob: in.KeyInfo.WrappedKey,
}
output, err := k.client.Decrypt(input)
output, err := k.client.Decrypt(ctx, input)
if err != nil {
return nil, fmt.Errorf("error decrypting data encryption key: %w", err)
}
Expand All @@ -277,48 +288,47 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp
}

// Client returns the AWS KMS client used by the wrapper.
func (k *Wrapper) Client() kmsiface.KMSAPI {
func (k *Wrapper) Client() KmsApi {
return k.client
}

// GetAwsKmsClient returns an instance of the KMS client.
func (k *Wrapper) GetAwsKmsClient() (*kms.KMS, error) {
credsConfig := &awsutil.CredentialsConfig{}

credsConfig.AccessKey = k.accessKey
credsConfig.SecretKey = k.secretKey
credsConfig.SessionToken = k.sessionToken
credsConfig.Filename = k.sharedCredsFilename
credsConfig.Profile = k.sharedCredsProfile
credsConfig.RoleARN = k.roleArn
credsConfig.RoleSessionName = k.roleSessionName
credsConfig.WebIdentityTokenFile = k.webIdentityTokenFile
credsConfig.Region = k.region
credsConfig.Logger = k.logger

credsConfig.HTTPClient = cleanhttp.DefaultClient()

creds, err := credsConfig.GenerateCredentialChain()
func (k *Wrapper) GetAwsKmsClient(ctx context.Context) (*kms.Client, error) {
credsConfig := &awsutil.CredentialsConfig{
AccessKey: k.accessKey,
SecretKey: k.secretKey,
SessionToken: k.sessionToken,
Filename: k.sharedCredsFilename,
Profile: k.sharedCredsProfile,
RoleARN: k.roleArn,
RoleSessionName: k.roleSessionName,
WebIdentityTokenFile: k.webIdentityTokenFile,
Region: k.region,
Logger: k.logger,
HTTPClient: cleanhttp.DefaultClient(),
}
Comment thread
johanbrandhorst marked this conversation as resolved.

creds, err := credsConfig.GenerateCredentialChain(ctx)
if err != nil {
return nil, err
}

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(credsConfig.Region),
HTTPClient: cleanhttp.DefaultClient(),
clientOpts := []func(*config.LoadOptions) error{
config.WithCredentialsProvider(creds.Credentials),
config.WithRegion(k.region),
config.WithHTTPClient(cleanhttp.DefaultClient()),
}

if k.endpoint != "" {
awsConfig.Endpoint = aws.String(k.endpoint)
clientOpts = append(clientOpts, config.WithBaseEndpoint(k.endpoint))
}

sess, err := session.NewSession(awsConfig)
cfg, err := config.LoadDefaultConfig(ctx, clientOpts...)
if err != nil {
return nil, err
}

client := kms.New(sess)
client := kms.NewFromConfig(cfg)

return client, nil
}
Loading
Loading