Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 104 additions & 92 deletions wrappers/awskms/awskms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import (
"testing"

wrapping "github.com/hashicorp/go-kms-wrapping/v2"
"github.com/stretchr/testify/suite"

"go.uber.org/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const (
Expand All @@ -20,56 +19,50 @@ const (
envAwsRegion = "AWS_REGION"
)

type AwsKmsSuite struct {
suite.Suite
ctrl *gomock.Controller
wrapperWithMock *Wrapper
}

func TestSuite(t *testing.T) {
suite.Run(t, new(AwsKmsSuite))
}

func (s *AwsKmsSuite) SetupSubTest() {
s.ctrl = gomock.NewController(s.T())
s.wrapperWithMock = NewWrapper()
s.wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}
}

func (s *AwsKmsSuite) TestSetConfig() {
func TestSetConfig(t *testing.T) {
// Works around lack of AWS_REGION var in CI
if os.Getenv(envAwsRegion) == "" {
os.Setenv(envAwsRegion, "us-west-2")
defer os.Setenv(envAwsRegion, "")
}

s.Run("Failure - No wrapper key ID", func() {
t.Run("Failure - No wrapper key ID", func(t *testing.T) {
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}
oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId)
os.Setenv(EnvAwsKmsWrapperKeyId, "")
defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId)

_, err := s.wrapperWithMock.SetConfig(context.Background())
s.Require().Error(err, "expected error when AwsKms wrapping key ID is not provided")
_, err := wrapperWithMock.SetConfig(context.Background())
require.Error(t, err, "expected error when AwsKms wrapping key ID is not provided")
})

s.Run("Success - Test key ID pulled from environment variables", func() {
t.Run("Success - Test key ID pulled from environment variables", func(t *testing.T) {
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}
oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId)
os.Setenv(EnvAwsKmsWrapperKeyId, awsTestKeyId)
defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId)

_, err := s.wrapperWithMock.SetConfig(context.Background())
s.Require().NoError(err)
_, err := wrapperWithMock.SetConfig(context.Background())
require.NoError(t, err)
})

s.Run("Success - Ignore environment variables", func() {
t.Run("Success - Ignore environment variables", func(t *testing.T) {
// Setup environment values to ignore for the following values
for _, envVar := range []string{EnvAwsKmsWrapperKeyId, EnvVaultAwsKmsSealKeyId, EnvAwsKmsEndpoint, EnvAwsKmsEndpoint} {
oldVal := os.Getenv(envVar)
os.Setenv(envVar, "")
defer os.Setenv(envVar, oldVal)
}
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}

config := map[string]string{
"disallow_env_vars": "true",
Expand All @@ -79,83 +72,103 @@ func (s *AwsKmsSuite) TestSetConfig() {
"endpoint": "my-endpoint",
}

_, err := s.wrapperWithMock.SetConfig(context.Background(), wrapping.WithConfigMap(config))
s.Require().NoError(err)
_, err := wrapperWithMock.SetConfig(context.Background(), wrapping.WithConfigMap(config))
require.NoError(t, err)

s.Require().Equal(config["access_key"], s.wrapperWithMock.accessKey)
s.Require().Equal(config["secret_key"], s.wrapperWithMock.secretKey)
s.Require().Equal(config["kms_key_id"], s.wrapperWithMock.keyId)
s.Require().Equal(config["endpoint"], s.wrapperWithMock.endpoint)
require.Equal(t, config["access_key"], wrapperWithMock.accessKey)
require.Equal(t, config["secret_key"], wrapperWithMock.secretKey)
require.Equal(t, config["kms_key_id"], wrapperWithMock.keyId)
require.Equal(t, config["endpoint"], wrapperWithMock.endpoint)
})

s.Run("Success - endpoint set automatically", func() {
_, err := s.wrapperWithMock.SetConfig(s.T().Context(), WithKeyNotRequired(true))
s.Require().NoError(err)
t.Run("Success - endpoint set automatically", func(t *testing.T) {
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}
_, err := wrapperWithMock.SetConfig(t.Context(), WithKeyNotRequired(true))
require.NoError(t, err)

c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context())
s.Require().NoError(err)
s.Assert().Nil(c.Options().BaseEndpoint)
c, err := wrapperWithMock.GetAwsKmsClient(t.Context())
require.NoError(t, err)
require.Nil(t, c.Options().BaseEndpoint)
})

s.Run("Success - custom endpoint set from environment variables", func() {
t.Run("Success - custom endpoint set from environment variables", func(t *testing.T) {
expectedEndpoint := "https://example.com/0"
oldEndpoint := os.Getenv(EnvAwsKmsEndpoint)
os.Setenv(EnvAwsKmsEndpoint, expectedEndpoint)
defer os.Setenv(EnvAwsKmsEndpoint, oldEndpoint)
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}

_, err := s.wrapperWithMock.SetConfig(s.T().Context(), WithKeyNotRequired(true))
s.Require().NoError(err)
_, err := wrapperWithMock.SetConfig(t.Context(), WithKeyNotRequired(true))
require.NoError(t, err)

c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context())
s.Require().NoError(err)
s.Assert().Equal(expectedEndpoint, *(c.Options().BaseEndpoint))
c, err := wrapperWithMock.GetAwsKmsClient(t.Context())
require.NoError(t, err)
assert.Equal(t, expectedEndpoint, *(c.Options().BaseEndpoint))
})

s.Run("Success - custom endpoint set from config", func() {
t.Run("Success - custom endpoint set from config", func(t *testing.T) {
expectedEndpoint := "https://example.com/1"

cfg := map[string]string{
"endpoint": expectedEndpoint,
}
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}

_, err := s.wrapperWithMock.SetConfig(s.T().Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true))
s.Require().NoError(err)
_, err := wrapperWithMock.SetConfig(t.Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true))
require.NoError(t, err)

c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context())
s.Require().NoError(err)
s.Assert().Equal(expectedEndpoint, *(c.Options().BaseEndpoint))
c, err := wrapperWithMock.GetAwsKmsClient(t.Context())
require.NoError(t, err)
assert.Equal(t, expectedEndpoint, *(c.Options().BaseEndpoint))
})

s.Run("Success - custom endpoint set from environment variables taking precedence over config", func() {
t.Run("Success - custom endpoint set from environment variables taking precedence over config", func(t *testing.T) {
expectedEndpoint := "https://example.com/2"
oldEndpoint := os.Getenv(EnvAwsKmsEndpoint)
os.Setenv(EnvAwsKmsEndpoint, expectedEndpoint)
defer os.Setenv(EnvAwsKmsEndpoint, oldEndpoint)
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}

cfg := map[string]string{
"endpoint": "https://example.com/3",
}

_, err := s.wrapperWithMock.SetConfig(s.T().Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true))
s.Require().NoError(err)
_, err := wrapperWithMock.SetConfig(t.Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true))
require.NoError(t, err)

c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context())
s.Require().NoError(err)
s.Assert().Equal(expectedEndpoint, *(c.Options().BaseEndpoint))
c, err := wrapperWithMock.GetAwsKmsClient(t.Context())
require.NoError(t, err)
assert.Equal(t, expectedEndpoint, *(c.Options().BaseEndpoint))
})
}

func (s *AwsKmsSuite) TestEncryptAndDecrypt() {
s.Run("Success - mock client", func() {
func TestEncryptAndDecrypt(t *testing.T) {
t.Run("Success - mock client", func(t *testing.T) {
// Works around lack of AWS_REGION var in CI
if os.Getenv(envAwsRegion) == "" {
os.Setenv(envAwsRegion, "us-west-2")
defer os.Setenv(envAwsRegion, "")
}
wrapperWithMock := NewWrapper()
wrapperWithMock.client = &mockClient{
keyId: awsTestKeyId,
}
oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId)
os.Setenv(EnvAwsKmsWrapperKeyId, awsTestKeyId)
defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId)
encryptionRoundTrip(s, s.wrapperWithMock)
encryptionRoundTrip(t, wrapperWithMock)
})
// To run the concrete enryption test, the following env variables need to be set:
// - AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID
Expand All @@ -166,25 +179,25 @@ func (s *AwsKmsSuite) TestEncryptAndDecrypt() {
// - AWS_REGION
// - Works around https://hashicorp.atlassian.net/browse/ICU-17849

s.Run("Success - concrete client", func() {
t.Run("Success - concrete client", func(t *testing.T) {
if os.Getenv(EnvAwsKmsWrapperKeyId) == "" && os.Getenv(EnvVaultAwsKmsSealKeyId) == "" {
s.T().Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for concrete encryption test")
t.Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for concrete encryption test")
}
if os.Getenv("AWS_ACCESS_KEY_ID") == "" {
s.T().Skip("AWS_ACCESS_KEY_ID required for concrete encryption test")
t.Skip("AWS_ACCESS_KEY_ID required for concrete encryption test")
}
if os.Getenv("AWS_SECRET_ACCESS_KEY") == "" {
s.T().Skip("AWS_SECRET_ACCESS_KEY required for concrete encryption test")
t.Skip("AWS_SECRET_ACCESS_KEY required for concrete encryption test")
}
if os.Getenv("AWS_SESSION_TOKEN") == "" {
s.T().Skip("AWS_SESSION_TOKEN required for concrete encryption test")
t.Skip("AWS_SESSION_TOKEN required for concrete encryption test")
}
if os.Getenv("AWS_REGION") == "" {
s.T().Skip("AWS_REGION required for concrete encryption test")
t.Skip("AWS_REGION required for concrete encryption test")
}
w := NewWrapper()

encryptionRoundTrip(s, w)
encryptionRoundTrip(t, w)
})
}

Expand All @@ -197,18 +210,18 @@ func (s *AwsKmsSuite) TestEncryptAndDecrypt() {
// - Set AWS_PROFILE=$SINK_PROFILE (for shared profile through AWS_PROFILE)
// - Set TEST_PROFILE=$SINK_PROFILE (for shared profile through WithSharedCredsProfile)
// - Set AWS_REGION and AWS_KMS_WRAPPER_KEY_ID as above
func (s *AwsKmsSuite) TestSharedProfiles() {
func TestSharedProfiles(t *testing.T) {
if os.Getenv("AWS_REGION") == "" {
s.T().Skip("AWS_REGION required for shared profiles tests")
t.Skip("AWS_REGION required for shared profiles tests")
}
if os.Getenv(EnvAwsKmsWrapperKeyId) == "" && os.Getenv(EnvVaultAwsKmsSealKeyId) == "" {
s.T().Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for shared profiles tests")
t.Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for shared profiles tests")
}

s.Run("Success - shared profile from WithSharedCredsProfile", func() {
t.Run("Success - shared profile from WithSharedCredsProfile", func(t *testing.T) {
prof := os.Getenv("TEST_PROFILE")
if prof == "" {
s.T().Skip("TEST_PROFILE required for shared profile from WithSharedCredsProfile test")
t.Skip("TEST_PROFILE required for shared profile from WithSharedCredsProfile test")
}
// Prevent AWS_PROFILE from clobbering this test
if old := os.Getenv(envAwsProfile); old != "" {
Expand All @@ -218,47 +231,46 @@ func (s *AwsKmsSuite) TestSharedProfiles() {

w := NewWrapper()

_, err := w.SetConfig(s.T().Context(), WithSharedCredsProfile(prof))
s.Require().NoError(err)
_, err := w.SetConfig(t.Context(), WithSharedCredsProfile(prof))
require.NoError(t, err)

encryptionRoundTrip(s, w)
encryptionRoundTrip(t, w)
})

s.Run("Success - shared profile from AWS_PROFILE", func() {
t.Run("Success - shared profile from AWS_PROFILE", func(t *testing.T) {
// Default awskms config pulls shared creds from AWS_PROFILE if it's set
if os.Getenv(envAwsProfile) == "" {
s.T().Skip("AWS_PROFILE required for shared profile from AWS_PROFILE test")
t.Skip("AWS_PROFILE required for shared profile from AWS_PROFILE test")
}

w := NewWrapper()
_, err := w.SetConfig(s.T().Context())
s.Require().NoError(err)
encryptionRoundTrip(s, w)
_, err := w.SetConfig(t.Context())
require.NoError(t, err)
encryptionRoundTrip(t, w)
})

s.Run("Failure - no shared config", func() {
t.Run("Failure - no shared config", func(t *testing.T) {
if old := os.Getenv(envAwsProfile); old != "" {
os.Setenv(envAwsProfile, "")
defer os.Setenv(envAwsProfile, old)
}

w := NewWrapper()

_, err := w.SetConfig(s.T().Context(), WithSharedCredsProfile("this-profile-definitely-doesn't-exist"))
s.Require().Error(err)
_, err := w.SetConfig(t.Context(), WithSharedCredsProfile("this-profile-definitely-doesn't-exist"))
require.Error(t, err)
})

}

func encryptionRoundTrip(s *AwsKmsSuite, w *Wrapper) {
_, err := w.SetConfig(s.T().Context())
s.Require().NoError(err)
func encryptionRoundTrip(t *testing.T, w *Wrapper) {
_, err := w.SetConfig(t.Context())
require.NoError(t, err)

expected := []byte("foo")
swi, err := w.Encrypt(context.Background(), expected, nil)
s.Require().NoError(err)
require.NoError(t, err)

output, err := w.Decrypt(context.Background(), swi, nil)
s.Require().NoError(err)
s.Assert().Equal(expected, output)
require.NoError(t, err)
assert.Equal(t, expected, output)
}