From 2766c98e1d526deacbb49ee8fafa818c4942fec1 Mon Sep 17 00:00:00 2001 From: Timothy Yen Date: Tue, 28 Jun 2022 17:52:32 -0700 Subject: [PATCH] Update Cache.Get* methods to accept a context --- README.md | 8 +++---- scintegtests/integration_test.go | 20 ++++++++-------- secretcache/cache.go | 16 ++++++------- secretcache/cacheHook_test.go | 5 ++-- secretcache/cacheItem.go | 14 +++++------ secretcache/cacheObjects_test.go | 4 ++-- secretcache/cacheVersion.go | 12 +++++----- secretcache/cache_test.go | 41 ++++++++++++++++---------------- 8 files changed, 61 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index 2152907..194a973 100644 --- a/README.md +++ b/README.md @@ -33,15 +33,15 @@ package main import ( "github.com/aws/aws-lambda-go/lambda" - "github.com/aws/aws-secretsmanager-caching-go/secretcache" + "github.com/aws/aws-secretsmanager-caching-go/v2/secretcache" ) var( secretCache, _ = secretcache.New() ) -func HandleRequest(secretId string) string { - result, _ := secretCache.GetSecretString(secretId) +func HandleRequest(ctx context.Context, secretId string) string { + result, _ := secretCache.GetSecretString(ctx, secretId) // Use secret to connect to secured resource. return "Success" } @@ -85,4 +85,4 @@ We use GitHub issues for tracking bugs and caching library feature requests and ## License -This library is licensed under the Apache 2.0 License. \ No newline at end of file +This library is licensed under the Apache 2.0 License. diff --git a/scintegtests/integration_test.go b/scintegtests/integration_test.go index 6189554..2e318c7 100644 --- a/scintegtests/integration_test.go +++ b/scintegtests/integration_test.go @@ -160,7 +160,7 @@ func integTest_getSecretBinary(t *testing.T, api *secretsmanager.Client) string return "" } - resultBinary, err := cache.GetSecretBinary(*createResult.ARN) + resultBinary, err := cache.GetSecretBinary(context.Background(), *createResult.ARN) if err != nil { t.Error(err) @@ -200,7 +200,7 @@ func integTest_getSecretBinaryWithStage(t *testing.T, api *secretsmanager.Client return *createResult.ARN } - resultBinary, err := cache.GetSecretBinaryWithStage(*createResult.ARN, "AWSPREVIOUS") + resultBinary, err := cache.GetSecretBinaryWithStage(context.Background(), *createResult.ARN, "AWSPREVIOUS") if err != nil { t.Error(err) @@ -211,7 +211,7 @@ func integTest_getSecretBinaryWithStage(t *testing.T, api *secretsmanager.Client t.Error("Expected and result binary not the same") } - resultBinary, err = cache.GetSecretBinaryWithStage(*createResult.ARN, "AWSCURRENT") + resultBinary, err = cache.GetSecretBinaryWithStage(context.Background(), *createResult.ARN, "AWSCURRENT") if err != nil { t.Error(err) @@ -237,7 +237,7 @@ func integTest_getSecretString(t *testing.T, api *secretsmanager.Client) string return "" } - resultString, err := cache.GetSecretString(*createResult.ARN) + resultString, err := cache.GetSecretString(context.Background(), *createResult.ARN) if err != nil { t.Error(err) @@ -277,7 +277,7 @@ func integTest_getSecretStringWithStage(t *testing.T, api *secretsmanager.Client return *createResult.ARN } - resultString, err := cache.GetSecretStringWithStage(*createResult.ARN, "AWSPREVIOUS") + resultString, err := cache.GetSecretStringWithStage(context.Background(), *createResult.ARN, "AWSPREVIOUS") if err != nil { t.Error(err) @@ -288,7 +288,7 @@ func integTest_getSecretStringWithStage(t *testing.T, api *secretsmanager.Client t.Errorf("Expected and result secret string are different - \"%s\", \"%s\"", secretString, resultString) } - resultString, err = cache.GetSecretStringWithStage(*createResult.ARN, "AWSCURRENT") + resultString, err = cache.GetSecretStringWithStage(context.Background(), *createResult.ARN, "AWSCURRENT") if err != nil { t.Error(err) @@ -317,7 +317,7 @@ func integTest_getSecretStringWithTTL(t *testing.T, api *secretsmanager.Client) return "" } - resultString, err := cache.GetSecretString(*createResult.ARN) + resultString, err := cache.GetSecretString(context.Background(), *createResult.ARN) if err != nil { t.Error(err) @@ -342,7 +342,7 @@ func integTest_getSecretStringWithTTL(t *testing.T, api *secretsmanager.Client) return *createResult.ARN } - resultString, err = cache.GetSecretString(*createResult.ARN) + resultString, err = cache.GetSecretString(context.Background(), *createResult.ARN) if err != nil { t.Error(err) @@ -356,7 +356,7 @@ func integTest_getSecretStringWithTTL(t *testing.T, api *secretsmanager.Client) time.Sleep(time.Nanosecond * time.Duration(ttlNanoSeconds)) - resultString, err = cache.GetSecretString(*createResult.ARN) + resultString, err = cache.GetSecretString(context.Background(), *createResult.ARN) if updatedSecretString != resultString { t.Errorf("Expected cached secret to be same as updated version - \"%s\", \"%s\"", resultString, updatedSecretString) return *createResult.ARN @@ -371,7 +371,7 @@ func integTest_getSecretStringNoSecret(t *testing.T, api *secretsmanager.Client) ) secretName := "NoSuchSecret" - _, err := cache.GetSecretString(secretName) + _, err := cache.GetSecretString(context.Background(), secretName) var rnfe *types.ResourceNotFoundException diff --git a/secretcache/cache.go b/secretcache/cache.go index 09ee4d6..9ee053f 100644 --- a/secretcache/cache.go +++ b/secretcache/cache.go @@ -88,16 +88,16 @@ func (c *Cache) getCachedSecret(secretId string) *secretCacheItem { // GetSecretString gets the secret string value from the cache for given secret id and a default version stage. // Returns the secret sting and an error if operation failed. -func (c *Cache) GetSecretString(secretId string) (string, error) { - return c.GetSecretStringWithStage(secretId, DefaultVersionStage) +func (c *Cache) GetSecretString(ctx context.Context, secretId string) (string, error) { + return c.GetSecretStringWithStage(ctx, secretId, DefaultVersionStage) } // GetSecretStringWithStage gets the secret string value from the cache for given secret id and version stage. // Returns the secret sting and an error if operation failed. -func (c *Cache) GetSecretStringWithStage(secretId string, versionStage string) (string, error) { +func (c *Cache) GetSecretStringWithStage(ctx context.Context, secretId string, versionStage string) (string, error) { secretCacheItem := c.getCachedSecret(secretId) - getSecretValueOutput, err := secretCacheItem.getSecretValue(versionStage) + getSecretValueOutput, err := secretCacheItem.getSecretValue(ctx, versionStage) if err != nil { return "", err @@ -116,16 +116,16 @@ func (c *Cache) GetSecretStringWithStage(secretId string, versionStage string) ( // GetSecretBinary gets the secret binary value from the cache for given secret id and a default version stage. // Returns the secret binary and an error if operation failed. -func (c *Cache) GetSecretBinary(secretId string) ([]byte, error) { - return c.GetSecretBinaryWithStage(secretId, DefaultVersionStage) +func (c *Cache) GetSecretBinary(ctx context.Context, secretId string) ([]byte, error) { + return c.GetSecretBinaryWithStage(ctx, secretId, DefaultVersionStage) } // GetSecretBinaryWithStage gets the secret binary value from the cache for given secret id and version stage. // Returns the secret binary and an error if operation failed. -func (c *Cache) GetSecretBinaryWithStage(secretId string, versionStage string) ([]byte, error) { +func (c *Cache) GetSecretBinaryWithStage(ctx context.Context, secretId string, versionStage string) ([]byte, error) { secretCacheItem := c.getCachedSecret(secretId) - getSecretValueOutput, err := secretCacheItem.getSecretValue(versionStage) + getSecretValueOutput, err := secretCacheItem.getSecretValue(ctx, versionStage) if err != nil { return nil, err diff --git a/secretcache/cacheHook_test.go b/secretcache/cacheHook_test.go index f3521f0..5ebd289 100644 --- a/secretcache/cacheHook_test.go +++ b/secretcache/cacheHook_test.go @@ -15,6 +15,7 @@ package secretcache_test import ( "bytes" + "context" "testing" "github.com/aws/aws-secretsmanager-caching-go/v2/secretcache" @@ -44,7 +45,7 @@ func TestCacheHookString(t *testing.T) { func(c *secretcache.Cache) { c.CacheConfig.Hook = hook }, ) - result, err := secretCache.GetSecretString(secretId) + result, err := secretCache.GetSecretString(context.Background(), secretId) if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) @@ -75,7 +76,7 @@ func TestCacheHookBinary(t *testing.T) { func(c *secretcache.Cache) { c.CacheConfig.Hook = hook }, ) - result, err := secretCache.GetSecretBinary(secretId) + result, err := secretCache.GetSecretBinary(context.Background(), secretId) if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) diff --git a/secretcache/cacheItem.go b/secretcache/cacheItem.go index f7d7d5d..fdba1c8 100644 --- a/secretcache/cacheItem.go +++ b/secretcache/cacheItem.go @@ -78,12 +78,12 @@ func (ci *secretCacheItem) getVersionId(versionStage string) (string, bool) { // executeRefresh performs the actual refresh of the cached secret information. // Returns the DescribeSecret API result and an error if call failed. -func (ci *secretCacheItem) executeRefresh() (*secretsmanager.DescribeSecretOutput, error) { +func (ci *secretCacheItem) executeRefresh(ctx context.Context) (*secretsmanager.DescribeSecretOutput, error) { input := &secretsmanager.DescribeSecretInput{ SecretId: &ci.secretId, } - result, err := ci.client.DescribeSecret(context.Background(), input, addUserAgent) + result, err := ci.client.DescribeSecret(ctx, input, addUserAgent) var maxTTL int64 if ci.config.CacheItemTTL == 0 { @@ -130,14 +130,14 @@ func (ci *secretCacheItem) getVersion(versionStage string) (*cacheVersion, bool) } // refresh the cached object when needed. -func (ci *secretCacheItem) refresh() { +func (ci *secretCacheItem) refresh(ctx context.Context) { if !ci.isRefreshNeeded() { return } ci.refreshNeeded = false - result, err := ci.executeRefresh() + result, err := ci.executeRefresh(ctx) if err != nil { ci.errorCount++ @@ -156,7 +156,7 @@ func (ci *secretCacheItem) refresh() { // getSecretValue gets the cached secret value for the given version stage. // Returns the GetSecretValue API result and an error if operation fails. -func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager.GetSecretValueOutput, error) { +func (ci *secretCacheItem) getSecretValue(ctx context.Context, versionStage string) (*secretsmanager.GetSecretValueOutput, error) { if versionStage == "" && ci.config.VersionStage == "" { versionStage = DefaultVersionStage } else if versionStage == "" && ci.config.VersionStage != "" { @@ -166,7 +166,7 @@ func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager. ci.mux.Lock() defer ci.mux.Unlock() - ci.refresh() + ci.refresh(ctx) version, ok := ci.getVersion(versionStage) if !ok { @@ -181,7 +181,7 @@ func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager. } } - return version.getSecretValue() + return version.getSecretValue(ctx) } // setWithHook sets the cache item's data using the CacheHook, if one is configured. diff --git a/secretcache/cacheObjects_test.go b/secretcache/cacheObjects_test.go index 921f59d..5c55a33 100644 --- a/secretcache/cacheObjects_test.go +++ b/secretcache/cacheObjects_test.go @@ -72,7 +72,7 @@ func TestMaxCacheTTL(t *testing.T) { config := CacheConfig{CacheItemTTL: -1} cacheItem.config = config - _, err := cacheItem.executeRefresh() + _, err := cacheItem.executeRefresh(context.Background()) if err == nil { t.Fatalf("Expected error due to negative cache ttl") @@ -81,7 +81,7 @@ func TestMaxCacheTTL(t *testing.T) { config = CacheConfig{CacheItemTTL: 0} cacheItem.config = config - _, err = cacheItem.executeRefresh() + _, err = cacheItem.executeRefresh(context.Background()) if err != nil { t.Fatalf("Unexpected error on zero cache ttl") diff --git a/secretcache/cacheVersion.go b/secretcache/cacheVersion.go index 8eda676..d113e55 100644 --- a/secretcache/cacheVersion.go +++ b/secretcache/cacheVersion.go @@ -43,14 +43,14 @@ func (cv *cacheVersion) isRefreshNeeded() bool { } // refresh the cached object when needed. -func (cv *cacheVersion) refresh() { +func (cv *cacheVersion) refresh(ctx context.Context) { if !cv.isRefreshNeeded() { return } cv.refreshNeeded = false - result, err := cv.executeRefresh() + result, err := cv.executeRefresh(ctx) if err != nil { cv.errorCount++ @@ -70,21 +70,21 @@ func (cv *cacheVersion) refresh() { // executeRefresh performs the actual refresh of the cached secret information. // Returns the GetSecretValue API result and an error if operation fails. -func (cv *cacheVersion) executeRefresh() (*secretsmanager.GetSecretValueOutput, error) { +func (cv *cacheVersion) executeRefresh(ctx context.Context) (*secretsmanager.GetSecretValueOutput, error) { input := &secretsmanager.GetSecretValueInput{ SecretId: &cv.secretId, VersionId: &cv.versionId, } - return cv.client.GetSecretValue(context.Background(), input, addUserAgent) + return cv.client.GetSecretValue(ctx, input, addUserAgent) } // getSecretValue gets the cached secret version value. // Returns the GetSecretValue API cached result and an error if operation fails. -func (cv *cacheVersion) getSecretValue() (*secretsmanager.GetSecretValueOutput, error) { +func (cv *cacheVersion) getSecretValue(ctx context.Context) (*secretsmanager.GetSecretValueOutput, error) { cv.mux.Lock() defer cv.mux.Unlock() - cv.refresh() + cv.refresh(ctx) return cv.getWithHook(), cv.err } diff --git a/secretcache/cache_test.go b/secretcache/cache_test.go index 24358e4..d87176e 100644 --- a/secretcache/cache_test.go +++ b/secretcache/cache_test.go @@ -15,6 +15,7 @@ package secretcache_test import ( "bytes" + "context" "errors" "testing" @@ -36,7 +37,7 @@ func TestGetSecretString(t *testing.T) { secretCache, _ := secretcache.New( func(c *secretcache.Cache) { c.Client = &mockClient }, ) - result, err := secretCache.GetSecretString("test") + result, err := secretCache.GetSecretString(context.Background(), "test") if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) @@ -55,7 +56,7 @@ func TestGetSecretBinary(t *testing.T) { secretCache, _ := secretcache.New( func(c *secretcache.Cache) { c.Client = &mockClient }, ) - result, err := secretCache.GetSecretBinary("test") + result, err := secretCache.GetSecretBinary(context.Background(), "test") if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) @@ -79,13 +80,13 @@ func TestGetSecretMissing(t *testing.T) { func(c *secretcache.Cache) { c.Client = &mockClient }, ) - _, err := secretCache.GetSecretString("test") + _, err := secretCache.GetSecretString(context.Background(), "test") if err == nil { t.Fatalf("Expected to not find a SecretString in this version") } - _, err = secretCache.GetSecretBinary("test") + _, err = secretCache.GetSecretBinary(context.Background(), "test") if err == nil { t.Fatalf("Expected to not find a SecretString in this version") @@ -109,7 +110,7 @@ func TestGetSecretNoCurrent(t *testing.T) { func(c *secretcache.Cache) { c.Client = &mockClient }, ) - _, err := secretCache.GetSecretString("test") + _, err := secretCache.GetSecretString(context.Background(), "test") if err == nil { t.Fatalf("Expected to not find secret version") @@ -118,7 +119,7 @@ func TestGetSecretNoCurrent(t *testing.T) { mockClient.MockedGetResult.SecretString = nil mockClient.MockedGetResult.SecretBinary = []byte{0, 1, 0, 1, 0, 1, 0, 1} - _, err = secretCache.GetSecretBinary("test") + _, err = secretCache.GetSecretBinary(context.Background(), "test") if err == nil { t.Fatalf("Expected to not find secret version") @@ -135,13 +136,13 @@ func TestGetSecretVersionNotFound(t *testing.T) { func(c *secretcache.Cache) { c.Client = &mockClient }, ) - _, err := secretCache.GetSecretString(secretId) + _, err := secretCache.GetSecretString(context.Background(), secretId) if err == nil { t.Fatalf("Expected to not find secret version") } - _, err = secretCache.GetSecretBinary(secretId) + _, err = secretCache.GetSecretBinary(context.Background(), secretId) if err == nil { t.Fatalf("Expected to not find secret version") @@ -158,13 +159,13 @@ func TestGetSecretNoVersions(t *testing.T) { func(c *secretcache.Cache) { c.Client = &mockClient }, ) - _, err := secretCache.GetSecretString(secretId) + _, err := secretCache.GetSecretString(context.Background(), secretId) if err == nil { t.Fatalf("Expected to not find secret version") } - _, err = secretCache.GetSecretBinary(secretId) + _, err = secretCache.GetSecretBinary(context.Background(), secretId) if err == nil { t.Fatalf("Expected to not find secret version") @@ -178,7 +179,7 @@ func TestGetSecretStringMultipleTimes(t *testing.T) { ) for i := 0; i < 100; i++ { - result, err := secretCache.GetSecretString(secretId) + result, err := secretCache.GetSecretString(context.Background(), secretId) if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) } @@ -208,7 +209,7 @@ func TestGetSecretBinaryMultipleTimes(t *testing.T) { ) for i := 0; i < 100; i++ { - result, err := secretCache.GetSecretBinary(secretId) + result, err := secretCache.GetSecretBinary(context.Background(), secretId) if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) } @@ -236,7 +237,7 @@ func TestGetSecretStringRefresh(t *testing.T) { ) for i := 0; i < 10; i++ { - result, err := secretCache.GetSecretString(secretId) + result, err := secretCache.GetSecretString(context.Background(), secretId) if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) } @@ -259,7 +260,7 @@ func TestGetSecretBinaryRefresh(t *testing.T) { ) for i := 0; i < 10; i++ { - result, err := secretCache.GetSecretBinary(secretId) + result, err := secretCache.GetSecretBinary(context.Background(), secretId) if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) } @@ -278,7 +279,7 @@ func TestGetSecretStringWithStage(t *testing.T) { ) for i := 0; i < 10; i++ { - result, err := secretCache.GetSecretStringWithStage(secretId, "versionStage-42") + result, err := secretCache.GetSecretStringWithStage(context.Background(), secretId, "versionStage-42") if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) } @@ -300,7 +301,7 @@ func TestGetSecretBinaryWithStage(t *testing.T) { ) for i := 0; i < 10; i++ { - result, err := secretCache.GetSecretBinaryWithStage(secretId, "versionStage-42") + result, err := secretCache.GetSecretBinaryWithStage(context.Background(), secretId, "versionStage-42") if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) } @@ -322,7 +323,7 @@ func TestGetSecretStringMultipleNotFound(t *testing.T) { ) for i := 0; i < 100; i++ { - _, err := secretCache.GetSecretStringWithStage("test", "versionStage-42") + _, err := secretCache.GetSecretStringWithStage(context.Background(), "test", "versionStage-42") if err == nil { t.Fatalf("Expected error: secretNotFound for a missing secret") @@ -345,7 +346,7 @@ func TestGetSecretBinaryMultipleNotFound(t *testing.T) { ) for i := 0; i < 100; i++ { - _, err := secretCache.GetSecretBinaryWithStage("test", "versionStage-42") + _, err := secretCache.GetSecretBinaryWithStage(context.Background(), "test", "versionStage-42") if err == nil { t.Fatalf("Expected error: secretNotFound for a missing secret") @@ -364,7 +365,7 @@ func TestGetSecretVersionStageEmpty(t *testing.T) { func(c *secretcache.Cache) { c.Client = &mockClient }, ) - result, err := secretCache.GetSecretStringWithStage("test", "") + result, err := secretCache.GetSecretStringWithStage(context.Background(), "test", "") if err != nil { t.Fatalf("Unexpected error - %s", err.Error()) @@ -380,7 +381,7 @@ func TestGetSecretVersionStageEmpty(t *testing.T) { func(c *secretcache.Cache) { c.CacheConfig.VersionStage = "" }, ) - result, err = secretCache.GetSecretStringWithStage("test", "") + result, err = secretCache.GetSecretStringWithStage(context.Background(), "test", "") if err != nil { t.Fatalf("Unexpected error - %s", err.Error())