@@ -21,39 +21,53 @@ import (
21
21
"errors"
22
22
"fmt"
23
23
24
- kms "github.com/aws/aws-sdk-go-v2/service/secretsmanager"
25
- kmsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
24
+ "github.com/aws/aws-sdk-go-v2/aws"
25
+ secs "github.com/aws/aws-sdk-go-v2/service/secretsmanager"
26
+ secsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
26
27
"github.com/aws/aws-sdk-go-v2/service/ssm"
27
28
ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types"
29
+ "github.com/aws/aws-sdk-go-v2/service/sts"
30
+ "github.com/dwango/yashiro/internal/client/cache"
28
31
"github.com/dwango/yashiro/internal/values"
29
32
"github.com/dwango/yashiro/pkg/config"
30
33
)
31
34
32
- type ssmClient interface {
33
- GetParameter (ctx context.Context , params * ssm.GetParameterInput , optFns ... func (* ssm.Options )) (* ssm.GetParameterOutput , error )
34
- }
35
-
36
- type kmsClient interface {
37
- GetSecretValue (ctx context.Context , params * kms.GetSecretValueInput , optFns ... func (* kms.Options )) (* kms.GetSecretValueOutput , error )
38
- }
39
-
40
35
type awsClient struct {
41
36
ssmClient ssmClient
42
- kmsClient kmsClient
37
+ secsClient secsClient
43
38
parameterStoreValue []config.AwsParameterStoreValueConfig
44
39
secretsManagerValue []config.ValueConfig
45
40
}
46
41
47
- func newAwsClient (cfg * config.AwsConfig ) (Client , error ) {
48
- if cfg .SdkConfig == nil {
42
+ func newAwsClient (cfg * config.Config ) (Client , error ) {
43
+ if cfg .Aws . SdkConfig == nil {
49
44
return nil , fmt .Errorf ("require aws sdk config" )
50
45
}
51
46
47
+ var cc cache.Cache
48
+ if cfg .Global .EnableCache {
49
+ // get AWS account ID
50
+ accountID , err := getAwsAccountId (cfg .Aws .SdkConfig )
51
+ if err != nil {
52
+ return nil , err
53
+ }
54
+ cc , err = cache .New (cfg .Global .Cache , cache .WithCacheKeys ("aws" , cfg .Aws .SdkConfig .Region , accountID ))
55
+ if err != nil {
56
+ return nil , err
57
+ }
58
+ }
59
+
52
60
return & awsClient {
53
- ssmClient : ssm .NewFromConfig (* cfg .SdkConfig ),
54
- kmsClient : kms .NewFromConfig (* cfg .SdkConfig ),
55
- parameterStoreValue : cfg .ParameterStoreValues ,
56
- secretsManagerValue : cfg .SecretsManagerValues ,
61
+ ssmClient : & ssmClientWithCache {
62
+ client : ssm .NewFromConfig (* cfg .Aws .SdkConfig ),
63
+ cache : cc ,
64
+ },
65
+ secsClient : & secsClientWithCache {
66
+ client : secs .NewFromConfig (* cfg .Aws .SdkConfig ),
67
+ cache : cc ,
68
+ },
69
+ parameterStoreValue : cfg .Aws .ParameterStoreValues ,
70
+ secretsManagerValue : cfg .Aws .SecretsManagerValues ,
57
71
}, nil
58
72
}
59
73
@@ -80,12 +94,12 @@ func (c awsClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.V
80
94
}
81
95
82
96
for _ , v := range c .secretsManagerValue {
83
- output , err := c .kmsClient .GetSecretValue (ctx , & kms .GetSecretValueInput {
97
+ output , err := c .secsClient .GetSecretValue (ctx , & secs .GetSecretValueInput {
84
98
SecretId : & v .Name ,
85
99
})
86
100
87
101
if err != nil {
88
- var notFoundErr * kmsTypes .ResourceNotFoundException
102
+ var notFoundErr * secsTypes .ResourceNotFoundException
89
103
if ignoreNotFound && errors .As (err , & notFoundErr ) {
90
104
continue
91
105
}
@@ -99,3 +113,109 @@ func (c awsClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.V
99
113
100
114
return values , nil
101
115
}
116
+
117
+ type ssmClient interface {
118
+ GetParameter (ctx context.Context , params * ssm.GetParameterInput , optFns ... func (* ssm.Options )) (* ssm.GetParameterOutput , error )
119
+ }
120
+
121
+ type ssmClientWithCache struct {
122
+ client ssmClient
123
+ cache cache.Cache
124
+ }
125
+
126
+ func (c ssmClientWithCache ) GetParameter (ctx context.Context , params * ssm.GetParameterInput , optFns ... func (* ssm.Options )) (* ssm.GetParameterOutput , error ) {
127
+ if c .cache == nil {
128
+ return c .getParameter (ctx , params , optFns ... )
129
+ }
130
+
131
+ key := * params .Name // Name is required, so do not check nil
132
+ isSensitive := params .WithDecryption != nil && * params .WithDecryption
133
+
134
+ // Load from cache.
135
+ value , expired , err := c .cache .Load (ctx , key , isSensitive )
136
+ if err != nil {
137
+ return nil , err
138
+ }
139
+
140
+ // If a cache value is expired or not found, get a value from the external store.
141
+ if value == nil || expired {
142
+ output , err := c .getParameter (ctx , params , optFns ... )
143
+ if err != nil {
144
+ return nil , err
145
+ }
146
+
147
+ // Create or update cache.
148
+ if err := c .cache .Save (ctx , key , output .Parameter .Value , isSensitive ); err != nil {
149
+ return nil , err
150
+ }
151
+
152
+ return output , nil
153
+ }
154
+
155
+ return & ssm.GetParameterOutput {Parameter : & ssmTypes.Parameter {Value : value }}, nil
156
+ }
157
+
158
+ func (c ssmClientWithCache ) getParameter (ctx context.Context , params * ssm.GetParameterInput , optFns ... func (* ssm.Options )) (* ssm.GetParameterOutput , error ) {
159
+ output , err := c .client .GetParameter (ctx , params , optFns ... )
160
+ if err != nil {
161
+ return nil , err
162
+ }
163
+ return output , nil
164
+ }
165
+
166
+ type secsClient interface {
167
+ GetSecretValue (ctx context.Context , params * secs.GetSecretValueInput , optFns ... func (* secs.Options )) (* secs.GetSecretValueOutput , error )
168
+ }
169
+
170
+ type secsClientWithCache struct {
171
+ client secsClient
172
+ cache cache.Cache
173
+ }
174
+
175
+ func (c secsClientWithCache ) GetSecretValue (ctx context.Context , params * secs.GetSecretValueInput , optFns ... func (* secs.Options )) (* secs.GetSecretValueOutput , error ) {
176
+ if c .cache == nil {
177
+ return c .getSecretValue (ctx , params , optFns ... )
178
+ }
179
+
180
+ key := * params .SecretId // SecretId is required, so do not check nil
181
+
182
+ // Load from cache. Secret is always sensitive.
183
+ value , expired , err := c .cache .Load (ctx , key , true )
184
+ if err != nil {
185
+ return nil , err
186
+ }
187
+
188
+ // If a cache value is expired or not found, get a value from the external store.
189
+ if value == nil || expired {
190
+ output , err := c .getSecretValue (ctx , params , optFns ... )
191
+ if err != nil {
192
+ return nil , err
193
+ }
194
+
195
+ // Create or update cache.
196
+ if err := c .cache .Save (ctx , key , output .SecretString , true ); err != nil {
197
+ return nil , err
198
+ }
199
+
200
+ return output , nil
201
+ }
202
+
203
+ return & secs.GetSecretValueOutput {SecretString : value }, nil
204
+ }
205
+
206
+ func (c secsClientWithCache ) getSecretValue (ctx context.Context , params * secs.GetSecretValueInput , optFns ... func (* secs.Options )) (* secs.GetSecretValueOutput , error ) {
207
+ output , err := c .client .GetSecretValue (ctx , params , optFns ... )
208
+ if err != nil {
209
+ return nil , err
210
+ }
211
+ return output , nil
212
+ }
213
+
214
+ func getAwsAccountId (sdkConfig * aws.Config ) (string , error ) {
215
+ stsClient := sts .NewFromConfig (* sdkConfig )
216
+ output , err := stsClient .GetCallerIdentity (context .Background (), & sts.GetCallerIdentityInput {})
217
+ if err != nil {
218
+ return "" , err
219
+ }
220
+ return * output .Account , nil
221
+ }
0 commit comments