Skip to content

Commit 78703a3

Browse files
as14692EC2 Default User
and
EC2 Default User
authored
credentialspec: initialize s3 client with bucket's region (#3886)
* Fix: gMSA s3 test * Update NewS3Client * Updated mocks * Updated unit tests * Create mocks for modified S3ClientCreator interface * Modify mocks for S3ClientCreator interface * Add error handling on creating a new S3 client --------- Co-authored-by: EC2 Default User <[email protected]>
1 parent 6ba1af7 commit 78703a3

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

agent/s3/factory/factory.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ const (
3434

3535
type S3ClientCreator interface {
3636
NewS3ManagerClient(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3ManagerClient, error)
37-
NewS3Client(region string, creds credentials.IAMRoleCredentials) s3client.S3Client
37+
NewS3Client(bucket, region string, creds credentials.IAMRoleCredentials) (s3client.S3Client, error)
3838
}
3939

4040
// NewS3ClientCreator provide 2 implementations
@@ -65,15 +65,21 @@ func (*s3ClientCreator) NewS3ManagerClient(bucket, region string,
6565
}
6666

6767
// NewS3Client returns a new S3 client to support s3 operations which are not provided by s3manager.
68-
func (*s3ClientCreator) NewS3Client(region string,
69-
creds credentials.IAMRoleCredentials) s3client.S3Client {
68+
func (*s3ClientCreator) NewS3Client(bucket, region string,
69+
creds credentials.IAMRoleCredentials) (s3client.S3Client, error) {
7070
cfg := aws.NewConfig().
7171
WithHTTPClient(httpclient.New(roundtripTimeout, false)).
7272
WithCredentials(
7373
awscreds.NewStaticCredentials(creds.AccessKeyID, creds.SecretAccessKey,
7474
creds.SessionToken)).WithRegion(region)
7575
sess := session.Must(session.NewSession(cfg))
76-
return s3.New(sess)
76+
svc := s3.New(sess)
77+
bucketRegion, err := getRegionFromBucket(svc, bucket)
78+
if err != nil {
79+
return nil, err
80+
}
81+
sessWithRegion := session.Must(session.NewSession(cfg.WithRegion(bucketRegion)))
82+
return s3.New(sessWithRegion), nil
7783
}
7884
func getRegionFromBucket(svc *s3.S3, bucket string) (string, error) {
7985
input := &s3.GetBucketLocationInput{

agent/s3/factory/mocks/factory_mocks.go

+6-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/taskresource/credentialspec/credentialspec_linux.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,12 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialS
435435
return
436436
}
437437

438-
s3Client := cs.s3ClientCreator.NewS3Client(cs.region, iamCredentials)
438+
s3Client, err := cs.s3ClientCreator.NewS3Client(bucket, cs.region, iamCredentials)
439+
if err != nil {
440+
cs.setTerminalReason(err.Error())
441+
errorEvents <- err
442+
return
443+
}
439444

440445
credSpecJsonStringUnformatted, err := s3.GetObject(bucket, key, s3Client)
441446

agent/taskresource/credentialspec/credentialspec_linux_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValue(t *testing.T) {
374374
Body: io.NopCloser(strings.NewReader(testData)),
375375
}
376376
gomock.InOrder(
377-
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
377+
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
378378
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1),
379379
)
380380

@@ -439,7 +439,7 @@ func TestHandleS3DomainlessCredentialSpecFileGetS3SecretValue(t *testing.T) {
439439
Body: io.NopCloser(strings.NewReader(testData)),
440440
}
441441
gomock.InOrder(
442-
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
442+
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
443443
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(s3GetObjectResponse, nil).Times(1),
444444
)
445445

@@ -501,7 +501,7 @@ func TestHandleS3CredentialSpecFileGetS3SecretValueErr(t *testing.T) {
501501
}, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning)
502502

503503
gomock.InOrder(
504-
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any()).Return(mockS3Client),
504+
s3ClientCreator.EXPECT().NewS3Client(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Client, nil),
505505
mockS3Client.EXPECT().GetObject(gomock.Any()).Return(nil, errors.New("test-error")).Times(1),
506506
)
507507

0 commit comments

Comments
 (0)