diff --git a/wrappers/awskms/awskms_test.go b/wrappers/awskms/awskms_test.go index c62043c7..05b2e3cb 100644 --- a/wrappers/awskms/awskms_test.go +++ b/wrappers/awskms/awskms_test.go @@ -9,9 +9,8 @@ import ( "testing" wrapping "github.com/hashicorp/go-kms-wrapping/v2" - "github.com/stretchr/testify/suite" - - "go.uber.org/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -20,56 +19,50 @@ const ( envAwsRegion = "AWS_REGION" ) -type AwsKmsSuite struct { - suite.Suite - ctrl *gomock.Controller - wrapperWithMock *Wrapper -} - -func TestSuite(t *testing.T) { - suite.Run(t, new(AwsKmsSuite)) -} - -func (s *AwsKmsSuite) SetupSubTest() { - s.ctrl = gomock.NewController(s.T()) - s.wrapperWithMock = NewWrapper() - s.wrapperWithMock.client = &mockClient{ - keyId: awsTestKeyId, - } -} - -func (s *AwsKmsSuite) TestSetConfig() { +func TestSetConfig(t *testing.T) { // Works around lack of AWS_REGION var in CI if os.Getenv(envAwsRegion) == "" { os.Setenv(envAwsRegion, "us-west-2") defer os.Setenv(envAwsRegion, "") } - s.Run("Failure - No wrapper key ID", func() { + t.Run("Failure - No wrapper key ID", func(t *testing.T) { + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId) os.Setenv(EnvAwsKmsWrapperKeyId, "") defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId) - _, err := s.wrapperWithMock.SetConfig(context.Background()) - s.Require().Error(err, "expected error when AwsKms wrapping key ID is not provided") + _, err := wrapperWithMock.SetConfig(context.Background()) + require.Error(t, err, "expected error when AwsKms wrapping key ID is not provided") }) - s.Run("Success - Test key ID pulled from environment variables", func() { + t.Run("Success - Test key ID pulled from environment variables", func(t *testing.T) { + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId) os.Setenv(EnvAwsKmsWrapperKeyId, awsTestKeyId) defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId) - _, err := s.wrapperWithMock.SetConfig(context.Background()) - s.Require().NoError(err) + _, err := wrapperWithMock.SetConfig(context.Background()) + require.NoError(t, err) }) - s.Run("Success - Ignore environment variables", func() { + t.Run("Success - Ignore environment variables", func(t *testing.T) { // Setup environment values to ignore for the following values for _, envVar := range []string{EnvAwsKmsWrapperKeyId, EnvVaultAwsKmsSealKeyId, EnvAwsKmsEndpoint, EnvAwsKmsEndpoint} { oldVal := os.Getenv(envVar) os.Setenv(envVar, "") defer os.Setenv(envVar, oldVal) } + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } config := map[string]string{ "disallow_env_vars": "true", @@ -79,83 +72,103 @@ func (s *AwsKmsSuite) TestSetConfig() { "endpoint": "my-endpoint", } - _, err := s.wrapperWithMock.SetConfig(context.Background(), wrapping.WithConfigMap(config)) - s.Require().NoError(err) + _, err := wrapperWithMock.SetConfig(context.Background(), wrapping.WithConfigMap(config)) + require.NoError(t, err) - s.Require().Equal(config["access_key"], s.wrapperWithMock.accessKey) - s.Require().Equal(config["secret_key"], s.wrapperWithMock.secretKey) - s.Require().Equal(config["kms_key_id"], s.wrapperWithMock.keyId) - s.Require().Equal(config["endpoint"], s.wrapperWithMock.endpoint) + require.Equal(t, config["access_key"], wrapperWithMock.accessKey) + require.Equal(t, config["secret_key"], wrapperWithMock.secretKey) + require.Equal(t, config["kms_key_id"], wrapperWithMock.keyId) + require.Equal(t, config["endpoint"], wrapperWithMock.endpoint) }) - s.Run("Success - endpoint set automatically", func() { - _, err := s.wrapperWithMock.SetConfig(s.T().Context(), WithKeyNotRequired(true)) - s.Require().NoError(err) + t.Run("Success - endpoint set automatically", func(t *testing.T) { + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } + _, err := wrapperWithMock.SetConfig(t.Context(), WithKeyNotRequired(true)) + require.NoError(t, err) - c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context()) - s.Require().NoError(err) - s.Assert().Nil(c.Options().BaseEndpoint) + c, err := wrapperWithMock.GetAwsKmsClient(t.Context()) + require.NoError(t, err) + require.Nil(t, c.Options().BaseEndpoint) }) - s.Run("Success - custom endpoint set from environment variables", func() { + t.Run("Success - custom endpoint set from environment variables", func(t *testing.T) { expectedEndpoint := "https://example.com/0" oldEndpoint := os.Getenv(EnvAwsKmsEndpoint) os.Setenv(EnvAwsKmsEndpoint, expectedEndpoint) defer os.Setenv(EnvAwsKmsEndpoint, oldEndpoint) + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } - _, err := s.wrapperWithMock.SetConfig(s.T().Context(), WithKeyNotRequired(true)) - s.Require().NoError(err) + _, err := wrapperWithMock.SetConfig(t.Context(), WithKeyNotRequired(true)) + require.NoError(t, err) - c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context()) - s.Require().NoError(err) - s.Assert().Equal(expectedEndpoint, *(c.Options().BaseEndpoint)) + c, err := wrapperWithMock.GetAwsKmsClient(t.Context()) + require.NoError(t, err) + assert.Equal(t, expectedEndpoint, *(c.Options().BaseEndpoint)) }) - s.Run("Success - custom endpoint set from config", func() { + t.Run("Success - custom endpoint set from config", func(t *testing.T) { expectedEndpoint := "https://example.com/1" cfg := map[string]string{ "endpoint": expectedEndpoint, } + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } - _, err := s.wrapperWithMock.SetConfig(s.T().Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true)) - s.Require().NoError(err) + _, err := wrapperWithMock.SetConfig(t.Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true)) + require.NoError(t, err) - c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context()) - s.Require().NoError(err) - s.Assert().Equal(expectedEndpoint, *(c.Options().BaseEndpoint)) + c, err := wrapperWithMock.GetAwsKmsClient(t.Context()) + require.NoError(t, err) + assert.Equal(t, expectedEndpoint, *(c.Options().BaseEndpoint)) }) - s.Run("Success - custom endpoint set from environment variables taking precedence over config", func() { + t.Run("Success - custom endpoint set from environment variables taking precedence over config", func(t *testing.T) { expectedEndpoint := "https://example.com/2" oldEndpoint := os.Getenv(EnvAwsKmsEndpoint) os.Setenv(EnvAwsKmsEndpoint, expectedEndpoint) defer os.Setenv(EnvAwsKmsEndpoint, oldEndpoint) + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } cfg := map[string]string{ "endpoint": "https://example.com/3", } - _, err := s.wrapperWithMock.SetConfig(s.T().Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true)) - s.Require().NoError(err) + _, err := wrapperWithMock.SetConfig(t.Context(), wrapping.WithConfigMap(cfg), WithKeyNotRequired(true)) + require.NoError(t, err) - c, err := s.wrapperWithMock.GetAwsKmsClient(s.T().Context()) - s.Require().NoError(err) - s.Assert().Equal(expectedEndpoint, *(c.Options().BaseEndpoint)) + c, err := wrapperWithMock.GetAwsKmsClient(t.Context()) + require.NoError(t, err) + assert.Equal(t, expectedEndpoint, *(c.Options().BaseEndpoint)) }) } -func (s *AwsKmsSuite) TestEncryptAndDecrypt() { - s.Run("Success - mock client", func() { +func TestEncryptAndDecrypt(t *testing.T) { + t.Run("Success - mock client", func(t *testing.T) { // Works around lack of AWS_REGION var in CI if os.Getenv(envAwsRegion) == "" { os.Setenv(envAwsRegion, "us-west-2") defer os.Setenv(envAwsRegion, "") } + wrapperWithMock := NewWrapper() + wrapperWithMock.client = &mockClient{ + keyId: awsTestKeyId, + } oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId) os.Setenv(EnvAwsKmsWrapperKeyId, awsTestKeyId) defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId) - encryptionRoundTrip(s, s.wrapperWithMock) + encryptionRoundTrip(t, wrapperWithMock) }) // To run the concrete enryption test, the following env variables need to be set: // - AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID @@ -166,25 +179,25 @@ func (s *AwsKmsSuite) TestEncryptAndDecrypt() { // - AWS_REGION // - Works around https://hashicorp.atlassian.net/browse/ICU-17849 - s.Run("Success - concrete client", func() { + t.Run("Success - concrete client", func(t *testing.T) { if os.Getenv(EnvAwsKmsWrapperKeyId) == "" && os.Getenv(EnvVaultAwsKmsSealKeyId) == "" { - s.T().Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for concrete encryption test") + t.Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for concrete encryption test") } if os.Getenv("AWS_ACCESS_KEY_ID") == "" { - s.T().Skip("AWS_ACCESS_KEY_ID required for concrete encryption test") + t.Skip("AWS_ACCESS_KEY_ID required for concrete encryption test") } if os.Getenv("AWS_SECRET_ACCESS_KEY") == "" { - s.T().Skip("AWS_SECRET_ACCESS_KEY required for concrete encryption test") + t.Skip("AWS_SECRET_ACCESS_KEY required for concrete encryption test") } if os.Getenv("AWS_SESSION_TOKEN") == "" { - s.T().Skip("AWS_SESSION_TOKEN required for concrete encryption test") + t.Skip("AWS_SESSION_TOKEN required for concrete encryption test") } if os.Getenv("AWS_REGION") == "" { - s.T().Skip("AWS_REGION required for concrete encryption test") + t.Skip("AWS_REGION required for concrete encryption test") } w := NewWrapper() - encryptionRoundTrip(s, w) + encryptionRoundTrip(t, w) }) } @@ -197,18 +210,18 @@ func (s *AwsKmsSuite) TestEncryptAndDecrypt() { // - Set AWS_PROFILE=$SINK_PROFILE (for shared profile through AWS_PROFILE) // - Set TEST_PROFILE=$SINK_PROFILE (for shared profile through WithSharedCredsProfile) // - Set AWS_REGION and AWS_KMS_WRAPPER_KEY_ID as above -func (s *AwsKmsSuite) TestSharedProfiles() { +func TestSharedProfiles(t *testing.T) { if os.Getenv("AWS_REGION") == "" { - s.T().Skip("AWS_REGION required for shared profiles tests") + t.Skip("AWS_REGION required for shared profiles tests") } if os.Getenv(EnvAwsKmsWrapperKeyId) == "" && os.Getenv(EnvVaultAwsKmsSealKeyId) == "" { - s.T().Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for shared profiles tests") + t.Skip("AWSKMS_WRAPPER_KEY_ID or VAULT_AWSKMS_SEAL_KEY_ID required for shared profiles tests") } - s.Run("Success - shared profile from WithSharedCredsProfile", func() { + t.Run("Success - shared profile from WithSharedCredsProfile", func(t *testing.T) { prof := os.Getenv("TEST_PROFILE") if prof == "" { - s.T().Skip("TEST_PROFILE required for shared profile from WithSharedCredsProfile test") + t.Skip("TEST_PROFILE required for shared profile from WithSharedCredsProfile test") } // Prevent AWS_PROFILE from clobbering this test if old := os.Getenv(envAwsProfile); old != "" { @@ -218,25 +231,25 @@ func (s *AwsKmsSuite) TestSharedProfiles() { w := NewWrapper() - _, err := w.SetConfig(s.T().Context(), WithSharedCredsProfile(prof)) - s.Require().NoError(err) + _, err := w.SetConfig(t.Context(), WithSharedCredsProfile(prof)) + require.NoError(t, err) - encryptionRoundTrip(s, w) + encryptionRoundTrip(t, w) }) - s.Run("Success - shared profile from AWS_PROFILE", func() { + t.Run("Success - shared profile from AWS_PROFILE", func(t *testing.T) { // Default awskms config pulls shared creds from AWS_PROFILE if it's set if os.Getenv(envAwsProfile) == "" { - s.T().Skip("AWS_PROFILE required for shared profile from AWS_PROFILE test") + t.Skip("AWS_PROFILE required for shared profile from AWS_PROFILE test") } w := NewWrapper() - _, err := w.SetConfig(s.T().Context()) - s.Require().NoError(err) - encryptionRoundTrip(s, w) + _, err := w.SetConfig(t.Context()) + require.NoError(t, err) + encryptionRoundTrip(t, w) }) - s.Run("Failure - no shared config", func() { + t.Run("Failure - no shared config", func(t *testing.T) { if old := os.Getenv(envAwsProfile); old != "" { os.Setenv(envAwsProfile, "") defer os.Setenv(envAwsProfile, old) @@ -244,21 +257,20 @@ func (s *AwsKmsSuite) TestSharedProfiles() { w := NewWrapper() - _, err := w.SetConfig(s.T().Context(), WithSharedCredsProfile("this-profile-definitely-doesn't-exist")) - s.Require().Error(err) + _, err := w.SetConfig(t.Context(), WithSharedCredsProfile("this-profile-definitely-doesn't-exist")) + require.Error(t, err) }) - } -func encryptionRoundTrip(s *AwsKmsSuite, w *Wrapper) { - _, err := w.SetConfig(s.T().Context()) - s.Require().NoError(err) +func encryptionRoundTrip(t *testing.T, w *Wrapper) { + _, err := w.SetConfig(t.Context()) + require.NoError(t, err) expected := []byte("foo") swi, err := w.Encrypt(context.Background(), expected, nil) - s.Require().NoError(err) + require.NoError(t, err) output, err := w.Decrypt(context.Background(), swi, nil) - s.Require().NoError(err) - s.Assert().Equal(expected, output) + require.NoError(t, err) + assert.Equal(t, expected, output) }