diff --git a/backend/go.mod b/backend/go.mod index e1b77f4e44..87c3493b01 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -14,6 +14,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/iam v1.32.0 github.com/aws/aws-sdk-go-v2/service/kinesis v1.27.4 github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 + github.com/aws/aws-sdk-go-v2/service/s3control v1.44.7 github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 github.com/aws/smithy-go v1.20.2 github.com/bradleyfalzon/ghinstallation/v2 v2.7.0 diff --git a/backend/go.sum b/backend/go.sum index 7ddc69fe8f..67005106c9 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -111,6 +111,8 @@ github.com/aws/aws-sdk-go-v2/service/kinesis v1.27.4 h1:Oe8awBiS/iitcsRJB5+DHa3i github.com/aws/aws-sdk-go-v2/service/kinesis v1.27.4/go.mod h1:RCZCSFbieSgNG1RKegO26opXV4EXyef/vNBVJsUyHuw= github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 h1:6cnno47Me9bRykw9AEv9zkXE+5or7jz8TsskTTccbgc= github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1/go.mod h1:qmdkIIAC+GCLASF7R2whgNrJADz0QZPX+Seiw/i4S3o= +github.com/aws/aws-sdk-go-v2/service/s3control v1.44.7 h1:tpUe6VAwhNsOJRzxSUNypRnLHInLGTFDXECKIdvGxJw= +github.com/aws/aws-sdk-go-v2/service/s3control v1.44.7/go.mod h1:xywJi2/waU8+fglbs5ASVHKr5y7OAYsEBOyQwgQgTIc= github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= diff --git a/backend/mock/service/awsmock/awsmock.go b/backend/mock/service/awsmock/awsmock.go index 18962fdadf..7e162fe397 100644 --- a/backend/mock/service/awsmock/awsmock.go +++ b/backend/mock/service/awsmock/awsmock.go @@ -13,6 +13,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam" iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3control" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go/middleware" "github.com/golang/protobuf/ptypes/any" @@ -29,6 +30,13 @@ import ( type svc struct{} +func (s *svc) S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName string, accountID string) (*s3control.GetAccessPointPolicyOutput, error) { + return &s3control.GetAccessPointPolicyOutput{ + Policy: aws.String("{}"), + ResultMetadata: middleware.Metadata{}, + }, nil +} + func (s *svc) GetDirectClient(account string, region string) (clutchawsclient.DirectClient, error) { panic("implement me") } diff --git a/backend/service/aws/aws.go b/backend/service/aws/aws.go index 6192b6e8bd..8affadf166 100644 --- a/backend/service/aws/aws.go +++ b/backend/service/aws/aws.go @@ -24,6 +24,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/kinesis" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3control" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/iancoleman/strcase" "github.com/uber-go/tally/v4" @@ -155,6 +156,7 @@ func (c *client) createRegionalClients(accountAlias, region string, regions []st }, s3: s3.NewFromConfig(regionCfg), + s3control: s3control.NewFromConfig(regionCfg), kinesis: kinesis.NewFromConfig(regionCfg), ec2: ec2.NewFromConfig(regionCfg), autoscaling: autoscaling.NewFromConfig(regionCfg), @@ -178,6 +180,7 @@ type Client interface { S3GetBucketPolicy(ctx context.Context, account, region, bucket, accountID string) (*s3.GetBucketPolicyOutput, error) S3StreamingGet(ctx context.Context, account, region, bucket, key string) (io.ReadCloser, error) + S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName, accountID string) (*s3control.GetAccessPointPolicyOutput, error) DescribeTable(ctx context.Context, account, region, tableName string) (*dynamodbv1.Table, error) UpdateCapacity(ctx context.Context, account, region, tableName string, targetTableCapacity *dynamodbv1.Throughput, indexUpdates []*dynamodbv1.IndexUpdateAction, ignoreMaximums bool) (*dynamodbv1.Table, error) BatchGetItem(ctx context.Context, account, region string, params *dynamodb.BatchGetItemInput) (*dynamodb.BatchGetItemOutput, error) @@ -232,6 +235,7 @@ type regionalClient struct { iam iamClient kinesis kinesisClient s3 s3Client + s3control s3ControlClient sts stsClient } diff --git a/backend/service/aws/iface.go b/backend/service/aws/iface.go index 3b826c4ae5..e5bca28d94 100644 --- a/backend/service/aws/iface.go +++ b/backend/service/aws/iface.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/kinesis" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3control" "github.com/aws/aws-sdk-go-v2/service/sts" ) @@ -20,6 +21,9 @@ type s3Client interface { GetBucketPolicy(ctx context.Context, params *s3.GetBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.GetBucketPolicyOutput, error) } +type s3ControlClient interface { + GetAccessPointPolicy(ctx context.Context, params *s3control.GetAccessPointPolicyInput, optFns ...func(*s3control.Options)) (*s3control.GetAccessPointPolicyOutput, error) +} type stsClient interface { AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) diff --git a/backend/service/aws/s3control.go b/backend/service/aws/s3control.go new file mode 100644 index 0000000000..0ae6246d1a --- /dev/null +++ b/backend/service/aws/s3control.go @@ -0,0 +1,22 @@ +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3control" +) + +func (c *client) S3GetAccessPointPolicy(ctx context.Context, account, region, accessPointName, accountID string) (*s3control.GetAccessPointPolicyOutput, error) { + cl, err := c.getAccountRegionClient(account, region) + if err != nil { + return nil, err + } + + in := &s3control.GetAccessPointPolicyInput{ + AccountId: aws.String(accountID), + Name: aws.String(accessPointName), + } + + return cl.s3control.GetAccessPointPolicy(ctx, in) +} diff --git a/backend/service/aws/s3control_test.go b/backend/service/aws/s3control_test.go new file mode 100644 index 0000000000..d75493bb7a --- /dev/null +++ b/backend/service/aws/s3control_test.go @@ -0,0 +1,72 @@ +package aws + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3control" + "github.com/stretchr/testify/assert" +) + +func TestS3ControlGetAccessPointPolicy(t *testing.T) { + s3ControlClient := &mockS3Control{ + getAccessPointPolicyOutput: &s3control.GetAccessPointPolicyOutput{ + Policy: aws.String("policy"), + }, + } + + c := &client{ + currentAccountAlias: "default", + accounts: map[string]*accountClients{ + "default": { + clients: map[string]*regionalClient{ + "us-east-1": {region: "us-east-1", s3control: s3ControlClient}, + }, + }, + }, + } + + output, err := c.S3GetAccessPointPolicy(context.Background(), "default", "us-east-1", "access-point", "accountID") + assert.NoError(t, err) + assert.Equal(t, output.Policy, aws.String("policy")) +} + +func TestS3ControlGetAccessPointPolicyErrorHandling(t *testing.T) { + s3ControlClient := &mockS3Control{ + getAccessPointPolicyErr: fmt.Errorf("error"), + } + + c := &client{ + currentAccountAlias: "default", + accounts: map[string]*accountClients{ + "default": { + clients: map[string]*regionalClient{ + "us-east-1": {region: "us-east-1", s3control: s3ControlClient}, + }, + }, + }, + } + + output1, err1 := c.S3GetAccessPointPolicy(context.Background(), "default", "us-east-1", "access-point", "accountID") + assert.Nil(t, output1) + assert.Error(t, err1) + + // Test unknown region + output2, err2 := c.S3GetAccessPointPolicy(context.Background(), "default", "choice-region-1", "access-point", "accountID") + assert.Nil(t, output2) + assert.Error(t, err2) +} + +type mockS3Control struct { + getAccessPointPolicyOutput *s3control.GetAccessPointPolicyOutput + getAccessPointPolicyErr error +} + +func (m *mockS3Control) GetAccessPointPolicy(ctx context.Context, params *s3control.GetAccessPointPolicyInput, optFns ...func(*s3control.Options)) (*s3control.GetAccessPointPolicyOutput, error) { + if m.getAccessPointPolicyErr != nil { + return nil, m.getAccessPointPolicyErr + } + return m.getAccessPointPolicyOutput, nil +}