diff --git a/controlplane/eks/controllers/awsmanagedcontrolplane_controller_test.go b/controlplane/eks/controllers/awsmanagedcontrolplane_controller_test.go index 2247b92ea5..8a22464c87 100644 --- a/controlplane/eks/controllers/awsmanagedcontrolplane_controller_test.go +++ b/controlplane/eks/controllers/awsmanagedcontrolplane_controller_test.go @@ -20,20 +20,19 @@ import ( "context" "encoding/base64" "fmt" - "net/http" "strconv" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" + signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/eks" ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/aws/aws-sdk-go-v2/service/iam" iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" - stsrequest "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/sts" + stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" "github.com/golang/mock/gomock" . "github.com/onsi/gomega" @@ -54,8 +53,8 @@ import ( "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/iamauth/mock_iamauth" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/network" - "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/securitygroup" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface" "sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "sigs.k8s.io/cluster-api/util" @@ -76,7 +75,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) { ec2Mock *mocks.MockEC2API eksMock *mock_eksiface.MockEKSAPI iamMock *mock_iamauth.MockIAMAPI - stsMock *mock_stsiface.MockSTSAPI + stsMock *mock_stsiface.MockSTSClient awsNodeMock *mock_services.MockAWSNodeInterface iamAuthenticatorMock *mock_services.MockIAMAuthenticatorInterface kubeProxyMock *mock_services.MockKubeProxyInterface @@ -96,7 +95,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) { ec2Mock = mocks.NewMockEC2API(mockCtrl) eksMock = mock_eksiface.NewMockEKSAPI(mockCtrl) iamMock = mock_iamauth.NewMockIAMAPI(mockCtrl) - stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl) + stsMock = mock_stsiface.NewMockSTSClient(mockCtrl) // Mocking these as well, since the actual implementation requires a remote client to an actual cluster awsNodeMock = mock_services.NewMockAWSNodeInterface(mockCtrl) @@ -854,7 +853,7 @@ func mockedEKSControlPlaneIAMRole(g *WithT, iamRec *mock_iamauth.MockIAMAPIMockR }).After(getPolicyCall).Return(&iam.AttachRolePolicyOutput{}, nil) } -func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSAPIMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) { +func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSClientMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) { describeClusterCall := eksRec.DescribeCluster(ctx, &eks.DescribeClusterInput{ Name: aws.String("test-cluster"), }).Return(nil, &ekstypes.ResourceNotFoundException{ @@ -948,12 +947,14 @@ func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockE })).Return( clusterSgDesc, nil) - req, err := http.NewRequest(http.MethodGet, "foobar", http.NoBody) - g.Expect(err).To(BeNil()) - stsRec.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}).Return(&stsrequest.Request{ - HTTPRequest: req, - Operation: &stsrequest.Operation{}, - }, &sts.GetCallerIdentityOutput{}) + stsRec.PresignGetCallerIdentity(gomock.Any(), gomock.Any(), gomock.Any()).Return(&signerv4.PresignedHTTPRequest{ + URL: "https://example.com", + }, nil) + stsRec.GetCallerIdentity(gomock.Any(), gomock.Any()).Return(&stsv2.GetCallerIdentityOutput{ + Account: aws.String("123456789012"), + Arn: aws.String("arn:aws:iam::123456789012:user/test-user"), + UserId: aws.String("AIDACKCEVSQ6C2EXAMPLE"), + }, nil) eksRec.TagResource(ctx, &eks.TagResourceInput{ ResourceArn: clusterARN, diff --git a/controlplane/rosa/controllers/rosacontrolplane_controller.go b/controlplane/rosa/controllers/rosacontrolplane_controller.go index dc29d03eb9..afe5bf24c0 100644 --- a/controlplane/rosa/controllers/rosacontrolplane_controller.go +++ b/controlplane/rosa/controllers/rosacontrolplane_controller.go @@ -30,8 +30,6 @@ import ( "time" stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" - sts "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/google/go-cmp/cmp" idputils "github.com/openshift-online/ocm-common/pkg/idp/utils" cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" @@ -62,6 +60,7 @@ import ( "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/annotations" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" + stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/utils" @@ -92,7 +91,7 @@ type ROSAControlPlaneReconciler struct { WatchFilterValue string WaitInfraPeriod time.Duration Endpoints []scope.ServiceEndpoint - NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI + NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) // Exposing the restClientConfig for integration test. No need to initialize. restClientConfig *restclient.Config @@ -221,7 +220,11 @@ func (r *ROSAControlPlaneReconciler) reconcileNormal(ctx context.Context, rosaSc return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err) } - creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity)) + creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{ + Account: rosaScope.Identity.Account, + Arn: rosaScope.Identity.Arn, + UserId: rosaScope.Identity.UserId, + }) if err != nil { return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err) } @@ -354,7 +357,11 @@ func (r *ROSAControlPlaneReconciler) reconcileDelete(ctx context.Context, rosaSc return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err) } - creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity)) + creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{ + Account: rosaScope.Identity.Account, + Arn: rosaScope.Identity.Arn, + UserId: rosaScope.Identity.UserId, + }) if err != nil { return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err) } @@ -1130,12 +1137,3 @@ func buildAPIEndpoint(cluster *cmv1.Cluster) (*clusterv1.APIEndpoint, error) { Port: int32(port), //#nosec G109 G115 }, nil } - -// TODO: Remove this and update the aws-sdk lib to v2. -func convertStsV2(identity *sts.GetCallerIdentityOutput) *stsv2.GetCallerIdentityOutput { - return &stsv2.GetCallerIdentityOutput{ - Account: identity.Account, - Arn: identity.Arn, - UserId: identity.UserId, - } -} diff --git a/controlplane/rosa/controllers/rosacontrolplane_controller_test.go b/controlplane/rosa/controllers/rosacontrolplane_controller_test.go index 51dfb053de..41a5e6b4ee 100644 --- a/controlplane/rosa/controllers/rosacontrolplane_controller_test.go +++ b/controlplane/rosa/controllers/rosacontrolplane_controller_test.go @@ -28,9 +28,8 @@ import ( "testing" "time" + stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws" - sts "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/golang/mock/gomock" . "github.com/onsi/gomega" v1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" @@ -48,7 +47,8 @@ import ( rosacontrolplanev1 "sigs.k8s.io/cluster-api-provider-aws/v2/controlplane/rosa/api/v1beta2" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" - "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface" + stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa" "sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks" @@ -292,10 +292,10 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) { mockCtrl := gomock.NewController(t) ctx := context.TODO() ocmMock := mocks.NewMockOCMClient(mockCtrl) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) - getCallerIdentityResult := &sts.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")} - stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Return(getCallerIdentityResult, nil).Times(1) + getCallerIdentityResult := &stsv2.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")} + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(getCallerIdentityResult, nil).Times(1) expect := func(m *mocks.MockOCMClientMockRecorder) { m.ValidateHypershiftVersion(gomock.Any(), gomock.Any()).DoAndReturn(func(clusterId string, nodePoolID string) (bool, error) { @@ -396,7 +396,9 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) { Endpoints: []scope.ServiceEndpoint{}, Client: testEnv, restClientConfig: cfg, - NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock }, + NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient { + return stsMock + }, NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) { return ocmMock, nil }, diff --git a/exp/controllers/awsmachinepool_controller_test.go b/exp/controllers/awsmachinepool_controller_test.go index 71f25c86de..db2879ea86 100644 --- a/exp/controllers/awsmachinepool_controller_test.go +++ b/exp/controllers/awsmachinepool_controller_test.go @@ -51,7 +51,7 @@ import ( "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services" s3svc "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_s3iface" - "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/userdata" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -71,7 +71,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) { asgSvc *mock_services.MockASGInterface reconSvc *mock_services.MockMachinePoolReconcileInterface s3Mock *mock_s3iface.MockS3API - stsMock *mock_stsiface.MockSTSAPI + stsMock *mock_stsiface.MockSTSClient recorder *record.FakeRecorder awsMachinePool *expinfrav1.AWSMachinePool secret *corev1.Secret @@ -182,7 +182,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) { asgSvc = mock_services.NewMockASGInterface(mockCtrl) reconSvc = mock_services.NewMockMachinePoolReconcileInterface(mockCtrl) s3Mock = mock_s3iface.NewMockS3API(mockCtrl) - stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl) + stsMock = mock_stsiface.NewMockSTSClient(mockCtrl) // If the test hangs for 9 minutes, increase the value here to the number of events during a reconciliation loop recorder = record.NewFakeRecorder(2) diff --git a/exp/controllers/rosamachinepool_controller.go b/exp/controllers/rosamachinepool_controller.go index eab94fa3e9..64b7a3e65a 100644 --- a/exp/controllers/rosamachinepool_controller.go +++ b/exp/controllers/rosamachinepool_controller.go @@ -7,7 +7,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/blang/semver" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -35,6 +34,7 @@ import ( expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" + stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa" "sigs.k8s.io/cluster-api-provider-aws/v2/util/paused" @@ -52,7 +52,7 @@ type ROSAMachinePoolReconciler struct { Recorder record.EventRecorder WatchFilterValue string Endpoints []scope.ServiceEndpoint - NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI + NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsservice.STSClient NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) } diff --git a/exp/controllers/rosamachinepool_controller_test.go b/exp/controllers/rosamachinepool_controller_test.go index f33971d181..00e2dbc6b7 100644 --- a/exp/controllers/rosamachinepool_controller_test.go +++ b/exp/controllers/rosamachinepool_controller_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/golang/mock/gomock" . "github.com/onsi/gomega" cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1" @@ -26,7 +25,8 @@ import ( expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" - "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface" + stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa" "sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks" @@ -546,15 +546,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) { ocmMock := mocks.NewMockOCMClient(mockCtrl) test.expect(ocmMock.EXPECT()) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) - stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1) + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1) r := ROSAMachinePoolReconciler{ Recorder: recorder, WatchFilterValue: "", Endpoints: []scope.ServiceEndpoint{}, Client: testEnv, - NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock }, + NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient { + return stsMock + }, NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) { return ocmMock, nil }, @@ -641,15 +643,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) { } expect(ocmMock.EXPECT()) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) - stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1) + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1) r := ROSAMachinePoolReconciler{ Recorder: recorder, WatchFilterValue: "", Endpoints: []scope.ServiceEndpoint{}, Client: testEnv, - NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock }, + NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient { + return stsMock + }, NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) { return ocmMock, nil }, diff --git a/pkg/cloud/endpointsv2/endpoints.go b/pkg/cloud/endpointsv2/endpoints.go index b9d92aab18..597415050e 100644 --- a/pkg/cloud/endpointsv2/endpoints.go +++ b/pkg/cloud/endpointsv2/endpoints.go @@ -32,6 +32,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/sts" smithyendpoints "github.com/aws/smithy-go/endpoints" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" @@ -303,3 +304,25 @@ func (s *SSMEndpointResolver) ResolveEndpoint(ctx context.Context, params ssm.En params.Region = &endpoint.SigningRegion return ssm.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) } + +// STSEndpointResolver implements EndpointResolverV2 interface for STS. +type STSEndpointResolver struct { + *MultiServiceEndpointResolver +} + +// ResolveEndpoint for STS. +func (s *STSEndpointResolver) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (smithyendpoints.Endpoint, error) { + // If custom endpoint not found, return default endpoint for the service + log := logger.FromContext(ctx) + endpoint, ok := s.endpoints[sts.ServiceID] + + if !ok { + log.Debug("Custom endpoint not found, using default endpoint") + return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) + } + + log.Debug("Custom endpoint found, using custom endpoint", "endpoint", endpoint.URL) + params.Endpoint = &endpoint.URL + params.Region = &endpoint.SigningRegion + return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params) +} diff --git a/pkg/cloud/scope/clients.go b/pkg/cloud/scope/clients.go index 1e37a9b200..d00123b966 100644 --- a/pkg/cloud/scope/clients.go +++ b/pkg/cloud/scope/clients.go @@ -28,13 +28,12 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/ssm" + stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "k8s.io/apimachinery/pkg/runtime" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud" @@ -42,6 +41,7 @@ import ( awslogs "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/logs" awsmetrics "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/metrics" awsmetricsv2 "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/metricsv2" + stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/throttle" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/record" @@ -270,13 +270,26 @@ func NewIAMClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logg } // NewSTSClient creates a new STS API client for a given session. -func NewSTSClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logger.Wrapper, target runtime.Object) stsiface.STSAPI { - stsClient := sts.New(session.Session(), aws.NewConfig().WithLogLevel(awslogs.GetAWSLogLevel(logger.GetLogger())).WithLogger(awslogs.NewWrapLogr(logger.GetLogger()))) - stsClient.Handlers.Build.PushFrontNamed(getUserAgentHandler()) - stsClient.Handlers.CompleteAttempt.PushFront(awsmetrics.CaptureRequestMetrics(scopeUser.ControllerName())) - stsClient.Handlers.Complete.PushBack(recordAWSPermissionsIssue(target)) +func NewSTSClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logger.Wrapper, target runtime.Object) stsservice.STSClient { + cfg := session.SessionV2() + multiSvcEndpointResolver := endpointsv2.NewMultiServiceEndpointResolver() + stsEndpointResolver := &endpointsv2.STSEndpointResolver{ + MultiServiceEndpointResolver: multiSvcEndpointResolver, + } + + stsOpts := []func(*stsv2.Options){ + func(o *stsv2.Options) { + o.Logger = logger.GetAWSLogger() + o.ClientLogMode = awslogs.GetAWSLogLevelV2(logger.GetLogger()) + o.EndpointResolverV2 = stsEndpointResolver + }, + stsv2.WithAPIOptions( + awsmetricsv2.WithMiddlewares(scopeUser.ControllerName(), target), + awsmetricsv2.WithCAPAUserAgentMiddleware(), + ), + } - return stsClient + return stsservice.NewClientWrapper(stsv2.NewFromConfig(cfg, stsOpts...)) } // NewSSMClient creates a new Secrets API client for a given session. diff --git a/pkg/cloud/scope/rosacontrolplane.go b/pkg/cloud/scope/rosacontrolplane.go index 5404102ad4..f2d781cdf9 100644 --- a/pkg/cloud/scope/rosacontrolplane.go +++ b/pkg/cloud/scope/rosacontrolplane.go @@ -21,9 +21,8 @@ import ( "fmt" awsv2 "github.com/aws/aws-sdk-go-v2/aws" + stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" awsclient "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -34,6 +33,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2" rosacontrolplanev1 "sigs.k8s.io/cluster-api-provider-aws/v2/controlplane/rosa/api/v1beta2" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud" + stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/throttle" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" @@ -48,7 +48,7 @@ type ROSAControlPlaneScopeParams struct { ControlPlane *rosacontrolplanev1.ROSAControlPlane ControllerName string Endpoints []ServiceEndpoint - NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI + NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsservice.STSClient } // NewROSAControlPlaneScope creates a new ROSAControlPlaneScope from the supplied parameters. @@ -95,7 +95,7 @@ func NewROSAControlPlaneScope(params ROSAControlPlaneScopeParams) (*ROSAControlP managedScope.serviceLimitersV2 = serviceLimitersv2 stsClient := params.NewStsClient(managedScope, managedScope, managedScope, managedScope.ControlPlane) - identity, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + identity, err := stsClient.GetCallerIdentity(context.TODO(), &stsv2.GetCallerIdentityInput{}) if err != nil { return nil, fmt.Errorf("failed to identify the AWS caller: %w", err) } @@ -118,7 +118,7 @@ type ROSAControlPlaneScope struct { serviceLimiters throttle.ServiceLimiters serviceLimitersV2 throttle.ServiceLimiters controllerName string - Identity *sts.GetCallerIdentityOutput + Identity *stsv2.GetCallerIdentityOutput } // InfraCluster returns the AWSManagedControlPlane object. diff --git a/pkg/cloud/services/eks/config.go b/pkg/cloud/services/eks/config.go index 8207638ecd..153f293682 100644 --- a/pkg/cloud/services/eks/config.go +++ b/pkg/cloud/services/eks/config.go @@ -20,10 +20,10 @@ import ( "context" "encoding/base64" "fmt" - "time" ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts" + smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -45,6 +45,8 @@ const ( tokenPrefix = "k8s-aws-v1." //nolint:gosec clusterNameHeader = "x-k8s-aws-id" tokenAgeMins = 15 + xAmzExpiresHeader = "X-Amz-Expires" + xAmzExpires = 60 relativeKubeconfigKey = "relative" relativeTokenFileKey = "token-file" @@ -301,17 +303,28 @@ func (s *Service) createBaseKubeConfig(cluster *ekstypes.Cluster, userName strin func (s *Service) generateToken() (string, error) { eksClusterName := s.scope.KubernetesClusterName() - - req, output := s.STSClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) - req.HTTPRequest.Header.Add(clusterNameHeader, eksClusterName) - s.Trace("generating token for AWS identity", "user", output.UserId, "account", output.Account, "arn", output.Arn) - - presignedURL, err := req.Presign(tokenAgeMins * time.Minute) + ctx := context.Background() + + presignedReq, err := s.STSClient.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) { + po.ClientOptions = append(po.ClientOptions, func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, + smithyhttp.SetHeaderValue(clusterNameHeader, eksClusterName), + smithyhttp.SetHeaderValue(xAmzExpiresHeader, fmt.Sprintf("%d", xAmzExpires)), + ) + }) + }) if err != nil { return "", fmt.Errorf("presigning AWS get caller identity: %w", err) } - encodedURL := base64.RawURLEncoding.EncodeToString([]byte(presignedURL)) + output, err := s.STSClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return "", fmt.Errorf("getting AWS caller identity: %w", err) + } + + s.Trace("generating token for AWS identity", "user", output.UserId, "account", output.Account, "arn", output.Arn) + + encodedURL := base64.RawURLEncoding.EncodeToString([]byte(presignedReq.URL)) return fmt.Sprintf("%s%s", tokenPrefix, encodedURL), nil } diff --git a/pkg/cloud/services/eks/config_test.go b/pkg/cloud/services/eks/config_test.go index b956dd1b5e..d6f64bd071 100644 --- a/pkg/cloud/services/eks/config_test.go +++ b/pkg/cloud/services/eks/config_test.go @@ -2,14 +2,12 @@ package eks import ( "context" - "net/http" - "net/url" "testing" + "github.com/aws/aws-sdk-go-v2/aws" + signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/golang/mock/gomock" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -36,26 +34,19 @@ func Test_createCAPIKubeconfigSecret(t *testing.T) { { name: "create kubeconfig secret", input: &ekstypes.Cluster{ - CertificateAuthority: &ekstypes.Certificate{Data: aws.String("")}, - Endpoint: aws.String("https://F00BA4.gr4.us-east-2.eks.amazonaws.com"), + Name: aws.String("cluster-foo"), + CertificateAuthority: &ekstypes.Certificate{Data: aws.String("LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0t")}, + Endpoint: aws.String("https://cluster-foo.us-east-2.eks.amazonaws.com"), }, serviceFunc: func() *Service { mockCtrl := gomock.NewController(t) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) - op := request.Request{ - Operation: &request.Operation{Name: "GetCallerIdentity", - HTTPMethod: "POST", - HTTPPath: "/", - }, - HTTPRequest: &http.Request{ - Header: make(http.Header), - URL: &url.URL{ - Scheme: "https", - Host: "F00BA4.gr4.us-east-2.eks.amazonaws.com", - }, - }, - } - stsMock.EXPECT().GetCallerIdentityRequest(gomock.Any()).Return(&op, &sts.GetCallerIdentityOutput{}) + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) + stsMock.EXPECT().PresignGetCallerIdentity(gomock.Any(), gomock.Any(), gomock.Any()).Return(&signerv4.PresignedHTTPRequest{URL: "https://example.com"}, nil) + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(&sts.GetCallerIdentityOutput{ + UserId: aws.String("FAKEUSERID"), + Account: aws.String("FAKEACCOUNT"), + Arn: aws.String("arn:aws:sts::FAKEACCOUNT:user/FAKEUSERID"), + }, nil).AnyTimes() scheme := runtime.NewScheme() _ = infrav1.AddToScheme(scheme) @@ -150,21 +141,13 @@ func Test_updateCAPIKubeconfigSecret(t *testing.T) { }, serviceFunc: func(tc testCase) *Service { mockCtrl := gomock.NewController(t) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) - op := request.Request{ - Operation: &request.Operation{Name: "GetCallerIdentity", - HTTPMethod: "POST", - HTTPPath: "/", - }, - HTTPRequest: &http.Request{ - Header: make(http.Header), - URL: &url.URL{ - Scheme: "https", - Host: "F00BA4.gr4.us-east-2.eks.amazonaws.com", - }, - }, - } - stsMock.EXPECT().GetCallerIdentityRequest(gomock.Any()).Return(&op, &sts.GetCallerIdentityOutput{}) + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) + stsMock.EXPECT().PresignGetCallerIdentity(gomock.Any(), gomock.Any(), gomock.Any()).Return(&signerv4.PresignedHTTPRequest{URL: "https://example.com"}, nil) + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(&sts.GetCallerIdentityOutput{ + UserId: aws.String("FAKEUSERID"), + Account: aws.String("FAKEACCOUNT"), + Arn: aws.String("arn:aws:sts::FAKEACCOUNT:user/FAKEUSERID"), + }, nil).AnyTimes() scheme := runtime.NewScheme() _ = infrav1.AddToScheme(scheme) diff --git a/pkg/cloud/services/eks/service.go b/pkg/cloud/services/eks/service.go index b893801969..9d1ab00c7f 100644 --- a/pkg/cloud/services/eks/service.go +++ b/pkg/cloud/services/eks/service.go @@ -23,12 +23,12 @@ import ( "github.com/aws/aws-sdk-go-v2/service/autoscaling" "github.com/aws/aws-sdk-go-v2/service/eks" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/common" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/eks/iam" + stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" ) // EKSAPI defines the EKS API interface. @@ -89,7 +89,7 @@ type Service struct { EC2Client common.EC2API EKSClient EKSAPI iam.IAMService - STSClient stsiface.STSAPI + STSClient stsservice.STSClient } // ServiceOpts defines the functional arguments for the service. @@ -134,7 +134,7 @@ type NodegroupService struct { AutoscalingClient *autoscaling.Client EKSClient EKSAPI iam.IAMService - STSClient stsiface.STSAPI + STSClient stsservice.STSClient } // NewNodegroupService returns a new service given the api clients. @@ -159,7 +159,7 @@ type FargateService struct { scope *scope.FargateProfileScope EKSClient EKSAPI iam.IAMService - STSClient stsiface.STSAPI + STSClient stsservice.STSClient } // NewFargateService returns a new service given the api clients. diff --git a/pkg/cloud/services/s3/mock_stsiface/doc.go b/pkg/cloud/services/s3/mock_stsiface/doc.go deleted file mode 100644 index 429a95b586..0000000000 --- a/pkg/cloud/services/s3/mock_stsiface/doc.go +++ /dev/null @@ -1,22 +0,0 @@ -/* -Copyright 2019 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package mock_stsiface provides a mock implementation for the STSAPI interface. -// Run go generate to regenerate this mock. -// -//go:generate ../../../../../hack/tools/bin/mockgen -destination stsapi_mock.go -package mock_stsiface github.com/aws/aws-sdk-go/service/sts/stsiface STSAPI -//go:generate /usr/bin/env bash -c "cat ../../../../../hack/boilerplate/boilerplate.generatego.txt stsapi_mock.go > _stsapi_mock.go && mv _stsapi_mock.go stsapi_mock.go" -package mock_stsiface //nolint:stylecheck diff --git a/pkg/cloud/services/s3/s3.go b/pkg/cloud/services/s3/s3.go index 64fca21dbb..982ba46e92 100644 --- a/pkg/cloud/services/s3/s3.go +++ b/pkg/cloud/services/s3/s3.go @@ -28,9 +28,8 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" + stsv2 "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/pkg/errors" "k8s.io/utils/ptr" @@ -39,6 +38,7 @@ import ( iam "sigs.k8s.io/cluster-api-provider-aws/v2/iam/api/v1beta1" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/awserrors" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" + stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/userdata" "sigs.k8s.io/cluster-api-provider-aws/v2/util/system" ) @@ -53,7 +53,7 @@ type Service struct { scope scope.S3Scope S3Client S3API S3PresignClient *s3.PresignClient - STSClient stsiface.STSAPI + STSClient stsservice.STSClient } // S3API is the subset of the AWS S3 API that is used by CAPA. @@ -512,7 +512,7 @@ func (s *Service) tagBucket(ctx context.Context, bucketName string) error { } func (s *Service) bucketPolicy(bucketName string) (string, error) { - accountID, err := s.STSClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + accountID, err := s.STSClient.GetCallerIdentity(context.Background(), &stsv2.GetCallerIdentityInput{}) if err != nil { return "", errors.Wrap(err, "getting account ID") } diff --git a/pkg/cloud/services/s3/s3_test.go b/pkg/cloud/services/s3/s3_test.go index c1a638a025..378d3114d3 100644 --- a/pkg/cloud/services/s3/s3_test.go +++ b/pkg/cloud/services/s3/s3_test.go @@ -29,7 +29,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" s3svc "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/golang/mock/gomock" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -42,7 +42,7 @@ import ( "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_s3iface" - "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface" + "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" ) @@ -117,10 +117,10 @@ func TestReconcileBucket(t *testing.T) { mockCtrl := gomock.NewController(t) s3Mock := mock_s3iface.NewMockS3API(mockCtrl) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) getCallerIdentityResult := &sts.GetCallerIdentityOutput{Account: aws.String("foo")} - stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Return(getCallerIdentityResult, nil).AnyTimes() + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(getCallerIdentityResult, nil).AnyTimes() scheme := runtime.NewScheme() _ = infrav1.AddToScheme(scheme) @@ -298,8 +298,8 @@ func TestReconcileBucket(t *testing.T) { s3Mock.EXPECT().PutBucketTagging(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) mockCtrl := gomock.NewController(t) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) - stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Return(nil, errors.New(t.Name())).AnyTimes() + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(nil, errors.New(t.Name())).AnyTimes() svc.STSClient = stsMock if err := svc.ReconcileBucket(context.TODO()); err == nil { @@ -896,10 +896,10 @@ func testService(t *testing.T, si *testServiceInput) (*s3.Service, *mock_s3iface mockCtrl := gomock.NewController(t) s3Mock := mock_s3iface.NewMockS3API(mockCtrl) - stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl) + stsMock := mock_stsiface.NewMockSTSClient(mockCtrl) getCallerIdentityResult := &sts.GetCallerIdentityOutput{Account: aws.String("foo")} - stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Return(getCallerIdentityResult, nil).AnyTimes() + stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(getCallerIdentityResult, nil).AnyTimes() scheme := runtime.NewScheme() _ = infrav1.AddToScheme(scheme) diff --git a/pkg/cloud/services/sts/mock_stsiface/doc.go b/pkg/cloud/services/sts/mock_stsiface/doc.go index 1c576fa536..648e4b4ef8 100644 --- a/pkg/cloud/services/sts/mock_stsiface/doc.go +++ b/pkg/cloud/services/sts/mock_stsiface/doc.go @@ -14,9 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package mock_stsiface provides a mock implementation for the STSAPI interface. +// Package mock_stsiface provides a mock implementation for the STSClient interface. // Run go generate to regenerate this mock. // -//go:generate ../../../../../hack/tools/bin/mockgen -destination stsiface_mock.go -package mock_stsiface github.com/aws/aws-sdk-go/service/sts/stsiface STSAPI -//go:generate /usr/bin/env bash -c "cat ../../../../../hack/boilerplate/boilerplate.generatego.txt stsiface_mock.go > _stsiface_mock.go && mv _stsiface_mock.go stsiface_mock.go" +//go:generate ../../../../../hack/tools/bin/mockgen -destination stsiface_mock_v2.go -package mock_stsiface sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts STSClient +//go:generate /usr/bin/env bash -c "cat ../../../../../hack/boilerplate/boilerplate.generatego.txt stsiface_mock_v2.go > _stsiface_mock_v2.go && mv _stsiface_mock_v2.go stsiface_mock_v2.go" +//go:generate ../../../../../hack/tools/bin/mockgen -destination stsiface_mock_v1.go -package mock_stsiface github.com/aws/aws-sdk-go/service/sts/stsiface STSAPI +//go:generate /usr/bin/env bash -c "cat ../../../../../hack/boilerplate/boilerplate.generatego.txt stsiface_mock_v1.go > _stsiface_mock_v1.go && mv _stsiface_mock_v1.go stsiface_mock_v1.go" package mock_stsiface //nolint:stylecheck diff --git a/pkg/cloud/services/sts/mock_stsiface/stsiface_mock.go b/pkg/cloud/services/sts/mock_stsiface/stsiface_mock.go deleted file mode 100644 index 047c9491fc..0000000000 --- a/pkg/cloud/services/sts/mock_stsiface/stsiface_mock.go +++ /dev/null @@ -1,453 +0,0 @@ -/* -Copyright The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/aws-sdk-go/service/sts/stsiface (interfaces: STSAPI) - -// Package mock_stsiface is a generated GoMock package. -package mock_stsiface - -import ( - context "context" - reflect "reflect" - - request "github.com/aws/aws-sdk-go/aws/request" - sts "github.com/aws/aws-sdk-go/service/sts" - gomock "github.com/golang/mock/gomock" -) - -// MockSTSAPI is a mock of STSAPI interface. -type MockSTSAPI struct { - ctrl *gomock.Controller - recorder *MockSTSAPIMockRecorder -} - -// MockSTSAPIMockRecorder is the mock recorder for MockSTSAPI. -type MockSTSAPIMockRecorder struct { - mock *MockSTSAPI -} - -// NewMockSTSAPI creates a new mock instance. -func NewMockSTSAPI(ctrl *gomock.Controller) *MockSTSAPI { - mock := &MockSTSAPI{ctrl: ctrl} - mock.recorder = &MockSTSAPIMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSTSAPI) EXPECT() *MockSTSAPIMockRecorder { - return m.recorder -} - -// AssumeRole mocks base method. -func (m *MockSTSAPI) AssumeRole(arg0 *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssumeRole", arg0) - ret0, _ := ret[0].(*sts.AssumeRoleOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AssumeRole indicates an expected call of AssumeRole. -func (mr *MockSTSAPIMockRecorder) AssumeRole(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRole", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRole), arg0) -} - -// AssumeRoleRequest mocks base method. -func (m *MockSTSAPI) AssumeRoleRequest(arg0 *sts.AssumeRoleInput) (*request.Request, *sts.AssumeRoleOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssumeRoleRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.AssumeRoleOutput) - return ret0, ret1 -} - -// AssumeRoleRequest indicates an expected call of AssumeRoleRequest. -func (mr *MockSTSAPIMockRecorder) AssumeRoleRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleRequest", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleRequest), arg0) -} - -// AssumeRoleWithContext mocks base method. -func (m *MockSTSAPI) AssumeRoleWithContext(arg0 context.Context, arg1 *sts.AssumeRoleInput, arg2 ...request.Option) (*sts.AssumeRoleOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "AssumeRoleWithContext", varargs...) - ret0, _ := ret[0].(*sts.AssumeRoleOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AssumeRoleWithContext indicates an expected call of AssumeRoleWithContext. -func (mr *MockSTSAPIMockRecorder) AssumeRoleWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleWithContext", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleWithContext), varargs...) -} - -// AssumeRoleWithSAML mocks base method. -func (m *MockSTSAPI) AssumeRoleWithSAML(arg0 *sts.AssumeRoleWithSAMLInput) (*sts.AssumeRoleWithSAMLOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssumeRoleWithSAML", arg0) - ret0, _ := ret[0].(*sts.AssumeRoleWithSAMLOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AssumeRoleWithSAML indicates an expected call of AssumeRoleWithSAML. -func (mr *MockSTSAPIMockRecorder) AssumeRoleWithSAML(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleWithSAML", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleWithSAML), arg0) -} - -// AssumeRoleWithSAMLRequest mocks base method. -func (m *MockSTSAPI) AssumeRoleWithSAMLRequest(arg0 *sts.AssumeRoleWithSAMLInput) (*request.Request, *sts.AssumeRoleWithSAMLOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssumeRoleWithSAMLRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.AssumeRoleWithSAMLOutput) - return ret0, ret1 -} - -// AssumeRoleWithSAMLRequest indicates an expected call of AssumeRoleWithSAMLRequest. -func (mr *MockSTSAPIMockRecorder) AssumeRoleWithSAMLRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleWithSAMLRequest", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleWithSAMLRequest), arg0) -} - -// AssumeRoleWithSAMLWithContext mocks base method. -func (m *MockSTSAPI) AssumeRoleWithSAMLWithContext(arg0 context.Context, arg1 *sts.AssumeRoleWithSAMLInput, arg2 ...request.Option) (*sts.AssumeRoleWithSAMLOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "AssumeRoleWithSAMLWithContext", varargs...) - ret0, _ := ret[0].(*sts.AssumeRoleWithSAMLOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AssumeRoleWithSAMLWithContext indicates an expected call of AssumeRoleWithSAMLWithContext. -func (mr *MockSTSAPIMockRecorder) AssumeRoleWithSAMLWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleWithSAMLWithContext", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleWithSAMLWithContext), varargs...) -} - -// AssumeRoleWithWebIdentity mocks base method. -func (m *MockSTSAPI) AssumeRoleWithWebIdentity(arg0 *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssumeRoleWithWebIdentity", arg0) - ret0, _ := ret[0].(*sts.AssumeRoleWithWebIdentityOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AssumeRoleWithWebIdentity indicates an expected call of AssumeRoleWithWebIdentity. -func (mr *MockSTSAPIMockRecorder) AssumeRoleWithWebIdentity(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleWithWebIdentity", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleWithWebIdentity), arg0) -} - -// AssumeRoleWithWebIdentityRequest mocks base method. -func (m *MockSTSAPI) AssumeRoleWithWebIdentityRequest(arg0 *sts.AssumeRoleWithWebIdentityInput) (*request.Request, *sts.AssumeRoleWithWebIdentityOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssumeRoleWithWebIdentityRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.AssumeRoleWithWebIdentityOutput) - return ret0, ret1 -} - -// AssumeRoleWithWebIdentityRequest indicates an expected call of AssumeRoleWithWebIdentityRequest. -func (mr *MockSTSAPIMockRecorder) AssumeRoleWithWebIdentityRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleWithWebIdentityRequest", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleWithWebIdentityRequest), arg0) -} - -// AssumeRoleWithWebIdentityWithContext mocks base method. -func (m *MockSTSAPI) AssumeRoleWithWebIdentityWithContext(arg0 context.Context, arg1 *sts.AssumeRoleWithWebIdentityInput, arg2 ...request.Option) (*sts.AssumeRoleWithWebIdentityOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "AssumeRoleWithWebIdentityWithContext", varargs...) - ret0, _ := ret[0].(*sts.AssumeRoleWithWebIdentityOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AssumeRoleWithWebIdentityWithContext indicates an expected call of AssumeRoleWithWebIdentityWithContext. -func (mr *MockSTSAPIMockRecorder) AssumeRoleWithWebIdentityWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRoleWithWebIdentityWithContext", reflect.TypeOf((*MockSTSAPI)(nil).AssumeRoleWithWebIdentityWithContext), varargs...) -} - -// DecodeAuthorizationMessage mocks base method. -func (m *MockSTSAPI) DecodeAuthorizationMessage(arg0 *sts.DecodeAuthorizationMessageInput) (*sts.DecodeAuthorizationMessageOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecodeAuthorizationMessage", arg0) - ret0, _ := ret[0].(*sts.DecodeAuthorizationMessageOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DecodeAuthorizationMessage indicates an expected call of DecodeAuthorizationMessage. -func (mr *MockSTSAPIMockRecorder) DecodeAuthorizationMessage(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodeAuthorizationMessage", reflect.TypeOf((*MockSTSAPI)(nil).DecodeAuthorizationMessage), arg0) -} - -// DecodeAuthorizationMessageRequest mocks base method. -func (m *MockSTSAPI) DecodeAuthorizationMessageRequest(arg0 *sts.DecodeAuthorizationMessageInput) (*request.Request, *sts.DecodeAuthorizationMessageOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecodeAuthorizationMessageRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.DecodeAuthorizationMessageOutput) - return ret0, ret1 -} - -// DecodeAuthorizationMessageRequest indicates an expected call of DecodeAuthorizationMessageRequest. -func (mr *MockSTSAPIMockRecorder) DecodeAuthorizationMessageRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodeAuthorizationMessageRequest", reflect.TypeOf((*MockSTSAPI)(nil).DecodeAuthorizationMessageRequest), arg0) -} - -// DecodeAuthorizationMessageWithContext mocks base method. -func (m *MockSTSAPI) DecodeAuthorizationMessageWithContext(arg0 context.Context, arg1 *sts.DecodeAuthorizationMessageInput, arg2 ...request.Option) (*sts.DecodeAuthorizationMessageOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "DecodeAuthorizationMessageWithContext", varargs...) - ret0, _ := ret[0].(*sts.DecodeAuthorizationMessageOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DecodeAuthorizationMessageWithContext indicates an expected call of DecodeAuthorizationMessageWithContext. -func (mr *MockSTSAPIMockRecorder) DecodeAuthorizationMessageWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodeAuthorizationMessageWithContext", reflect.TypeOf((*MockSTSAPI)(nil).DecodeAuthorizationMessageWithContext), varargs...) -} - -// GetAccessKeyInfo mocks base method. -func (m *MockSTSAPI) GetAccessKeyInfo(arg0 *sts.GetAccessKeyInfoInput) (*sts.GetAccessKeyInfoOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAccessKeyInfo", arg0) - ret0, _ := ret[0].(*sts.GetAccessKeyInfoOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetAccessKeyInfo indicates an expected call of GetAccessKeyInfo. -func (mr *MockSTSAPIMockRecorder) GetAccessKeyInfo(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccessKeyInfo", reflect.TypeOf((*MockSTSAPI)(nil).GetAccessKeyInfo), arg0) -} - -// GetAccessKeyInfoRequest mocks base method. -func (m *MockSTSAPI) GetAccessKeyInfoRequest(arg0 *sts.GetAccessKeyInfoInput) (*request.Request, *sts.GetAccessKeyInfoOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAccessKeyInfoRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.GetAccessKeyInfoOutput) - return ret0, ret1 -} - -// GetAccessKeyInfoRequest indicates an expected call of GetAccessKeyInfoRequest. -func (mr *MockSTSAPIMockRecorder) GetAccessKeyInfoRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccessKeyInfoRequest", reflect.TypeOf((*MockSTSAPI)(nil).GetAccessKeyInfoRequest), arg0) -} - -// GetAccessKeyInfoWithContext mocks base method. -func (m *MockSTSAPI) GetAccessKeyInfoWithContext(arg0 context.Context, arg1 *sts.GetAccessKeyInfoInput, arg2 ...request.Option) (*sts.GetAccessKeyInfoOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetAccessKeyInfoWithContext", varargs...) - ret0, _ := ret[0].(*sts.GetAccessKeyInfoOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetAccessKeyInfoWithContext indicates an expected call of GetAccessKeyInfoWithContext. -func (mr *MockSTSAPIMockRecorder) GetAccessKeyInfoWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccessKeyInfoWithContext", reflect.TypeOf((*MockSTSAPI)(nil).GetAccessKeyInfoWithContext), varargs...) -} - -// GetCallerIdentity mocks base method. -func (m *MockSTSAPI) GetCallerIdentity(arg0 *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCallerIdentity", arg0) - ret0, _ := ret[0].(*sts.GetCallerIdentityOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetCallerIdentity indicates an expected call of GetCallerIdentity. -func (mr *MockSTSAPIMockRecorder) GetCallerIdentity(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCallerIdentity", reflect.TypeOf((*MockSTSAPI)(nil).GetCallerIdentity), arg0) -} - -// GetCallerIdentityRequest mocks base method. -func (m *MockSTSAPI) GetCallerIdentityRequest(arg0 *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetCallerIdentityRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.GetCallerIdentityOutput) - return ret0, ret1 -} - -// GetCallerIdentityRequest indicates an expected call of GetCallerIdentityRequest. -func (mr *MockSTSAPIMockRecorder) GetCallerIdentityRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCallerIdentityRequest", reflect.TypeOf((*MockSTSAPI)(nil).GetCallerIdentityRequest), arg0) -} - -// GetCallerIdentityWithContext mocks base method. -func (m *MockSTSAPI) GetCallerIdentityWithContext(arg0 context.Context, arg1 *sts.GetCallerIdentityInput, arg2 ...request.Option) (*sts.GetCallerIdentityOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetCallerIdentityWithContext", varargs...) - ret0, _ := ret[0].(*sts.GetCallerIdentityOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetCallerIdentityWithContext indicates an expected call of GetCallerIdentityWithContext. -func (mr *MockSTSAPIMockRecorder) GetCallerIdentityWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCallerIdentityWithContext", reflect.TypeOf((*MockSTSAPI)(nil).GetCallerIdentityWithContext), varargs...) -} - -// GetFederationToken mocks base method. -func (m *MockSTSAPI) GetFederationToken(arg0 *sts.GetFederationTokenInput) (*sts.GetFederationTokenOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetFederationToken", arg0) - ret0, _ := ret[0].(*sts.GetFederationTokenOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetFederationToken indicates an expected call of GetFederationToken. -func (mr *MockSTSAPIMockRecorder) GetFederationToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFederationToken", reflect.TypeOf((*MockSTSAPI)(nil).GetFederationToken), arg0) -} - -// GetFederationTokenRequest mocks base method. -func (m *MockSTSAPI) GetFederationTokenRequest(arg0 *sts.GetFederationTokenInput) (*request.Request, *sts.GetFederationTokenOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetFederationTokenRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.GetFederationTokenOutput) - return ret0, ret1 -} - -// GetFederationTokenRequest indicates an expected call of GetFederationTokenRequest. -func (mr *MockSTSAPIMockRecorder) GetFederationTokenRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFederationTokenRequest", reflect.TypeOf((*MockSTSAPI)(nil).GetFederationTokenRequest), arg0) -} - -// GetFederationTokenWithContext mocks base method. -func (m *MockSTSAPI) GetFederationTokenWithContext(arg0 context.Context, arg1 *sts.GetFederationTokenInput, arg2 ...request.Option) (*sts.GetFederationTokenOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetFederationTokenWithContext", varargs...) - ret0, _ := ret[0].(*sts.GetFederationTokenOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetFederationTokenWithContext indicates an expected call of GetFederationTokenWithContext. -func (mr *MockSTSAPIMockRecorder) GetFederationTokenWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFederationTokenWithContext", reflect.TypeOf((*MockSTSAPI)(nil).GetFederationTokenWithContext), varargs...) -} - -// GetSessionToken mocks base method. -func (m *MockSTSAPI) GetSessionToken(arg0 *sts.GetSessionTokenInput) (*sts.GetSessionTokenOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionToken", arg0) - ret0, _ := ret[0].(*sts.GetSessionTokenOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSessionToken indicates an expected call of GetSessionToken. -func (mr *MockSTSAPIMockRecorder) GetSessionToken(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionToken", reflect.TypeOf((*MockSTSAPI)(nil).GetSessionToken), arg0) -} - -// GetSessionTokenRequest mocks base method. -func (m *MockSTSAPI) GetSessionTokenRequest(arg0 *sts.GetSessionTokenInput) (*request.Request, *sts.GetSessionTokenOutput) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionTokenRequest", arg0) - ret0, _ := ret[0].(*request.Request) - ret1, _ := ret[1].(*sts.GetSessionTokenOutput) - return ret0, ret1 -} - -// GetSessionTokenRequest indicates an expected call of GetSessionTokenRequest. -func (mr *MockSTSAPIMockRecorder) GetSessionTokenRequest(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTokenRequest", reflect.TypeOf((*MockSTSAPI)(nil).GetSessionTokenRequest), arg0) -} - -// GetSessionTokenWithContext mocks base method. -func (m *MockSTSAPI) GetSessionTokenWithContext(arg0 context.Context, arg1 *sts.GetSessionTokenInput, arg2 ...request.Option) (*sts.GetSessionTokenOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetSessionTokenWithContext", varargs...) - ret0, _ := ret[0].(*sts.GetSessionTokenOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSessionTokenWithContext indicates an expected call of GetSessionTokenWithContext. -func (mr *MockSTSAPIMockRecorder) GetSessionTokenWithContext(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionTokenWithContext", reflect.TypeOf((*MockSTSAPI)(nil).GetSessionTokenWithContext), varargs...) -} diff --git a/pkg/cloud/services/s3/mock_stsiface/stsapi_mock.go b/pkg/cloud/services/sts/mock_stsiface/stsiface_mock_v1.go similarity index 100% rename from pkg/cloud/services/s3/mock_stsiface/stsapi_mock.go rename to pkg/cloud/services/sts/mock_stsiface/stsiface_mock_v1.go diff --git a/pkg/cloud/services/sts/mock_stsiface/stsiface_mock_v2.go b/pkg/cloud/services/sts/mock_stsiface/stsiface_mock_v2.go new file mode 100644 index 0000000000..a485439875 --- /dev/null +++ b/pkg/cloud/services/sts/mock_stsiface/stsiface_mock_v2.go @@ -0,0 +1,113 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Code generated by MockGen. DO NOT EDIT. +// Source: sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts (interfaces: STSClient) + +// Package mock_stsiface is a generated GoMock package. +package mock_stsiface + +import ( + context "context" + reflect "reflect" + + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + sts "github.com/aws/aws-sdk-go-v2/service/sts" + gomock "github.com/golang/mock/gomock" +) + +// MockSTSClient is a mock of STSClient interface. +type MockSTSClient struct { + ctrl *gomock.Controller + recorder *MockSTSClientMockRecorder +} + +// MockSTSClientMockRecorder is the mock recorder for MockSTSClient. +type MockSTSClientMockRecorder struct { + mock *MockSTSClient +} + +// NewMockSTSClient creates a new mock instance. +func NewMockSTSClient(ctrl *gomock.Controller) *MockSTSClient { + mock := &MockSTSClient{ctrl: ctrl} + mock.recorder = &MockSTSClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSTSClient) EXPECT() *MockSTSClientMockRecorder { + return m.recorder +} + +// AssumeRole mocks base method. +func (m *MockSTSClient) AssumeRole(arg0 context.Context, arg1 *sts.AssumeRoleInput, arg2 ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "AssumeRole", varargs...) + ret0, _ := ret[0].(*sts.AssumeRoleOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AssumeRole indicates an expected call of AssumeRole. +func (mr *MockSTSClientMockRecorder) AssumeRole(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssumeRole", reflect.TypeOf((*MockSTSClient)(nil).AssumeRole), varargs...) +} + +// GetCallerIdentity mocks base method. +func (m *MockSTSClient) GetCallerIdentity(arg0 context.Context, arg1 *sts.GetCallerIdentityInput, arg2 ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetCallerIdentity", varargs...) + ret0, _ := ret[0].(*sts.GetCallerIdentityOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCallerIdentity indicates an expected call of GetCallerIdentity. +func (mr *MockSTSClientMockRecorder) GetCallerIdentity(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCallerIdentity", reflect.TypeOf((*MockSTSClient)(nil).GetCallerIdentity), varargs...) +} + +// PresignGetCallerIdentity mocks base method. +func (m *MockSTSClient) PresignGetCallerIdentity(arg0 context.Context, arg1 *sts.GetCallerIdentityInput, arg2 ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "PresignGetCallerIdentity", varargs...) + ret0, _ := ret[0].(*v4.PresignedHTTPRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PresignGetCallerIdentity indicates an expected call of PresignGetCallerIdentity. +func (mr *MockSTSClientMockRecorder) PresignGetCallerIdentity(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PresignGetCallerIdentity", reflect.TypeOf((*MockSTSClient)(nil).PresignGetCallerIdentity), varargs...) +} diff --git a/pkg/cloud/services/sts/sts.go b/pkg/cloud/services/sts/sts.go new file mode 100644 index 0000000000..7ab8e7b3f8 --- /dev/null +++ b/pkg/cloud/services/sts/sts.go @@ -0,0 +1,64 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package sts provides an interface for AWS STS operations using the AWS SDK v2. +package sts + +import ( + "context" + + signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +// STSClient interface for STS operations using AWS SDK v2. +type STSClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*signerv4.PresignedHTTPRequest, error) + AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +} + +// ClientWrapper wraps both the regular STS client and presign client to implement STSClient interface. +type ClientWrapper struct { + client *sts.Client + presignClient *sts.PresignClient +} + +// NewClientWrapper creates a new STS client wrapper. +func NewClientWrapper(client *sts.Client) *ClientWrapper { + return &ClientWrapper{ + client: client, + presignClient: sts.NewPresignClient(client), + } +} + +// GetCallerIdentity calls the regular STS GetCallerIdentity operation. +func (c *ClientWrapper) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return c.client.GetCallerIdentity(ctx, params, optFns...) +} + +// PresignGetCallerIdentity creates a presigned URL for the GetCallerIdentity operation. +func (c *ClientWrapper) PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*signerv4.PresignedHTTPRequest, error) { + return c.presignClient.PresignGetCallerIdentity(ctx, params, optFns...) +} + +// AssumeRole calls the STS AssumeRole operation. +func (c *ClientWrapper) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { + return c.client.AssumeRole(ctx, params, optFns...) +} + +// Ensure our wrapper implements the STSClient interface. +var _ STSClient = (*ClientWrapper)(nil) diff --git a/test/e2e/shared/aws.go b/test/e2e/shared/aws.go index 7e7d9131b5..cf6e66040e 100644 --- a/test/e2e/shared/aws.go +++ b/test/e2e/shared/aws.go @@ -43,6 +43,7 @@ import ( elbtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing/types" "github.com/aws/aws-sdk-go-v2/service/iam" iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/client" awscreds "github.com/aws/aws-sdk-go/aws/credentials" @@ -53,7 +54,6 @@ import ( "github.com/aws/aws-sdk-go/service/ecrpublic" "github.com/aws/aws-sdk-go/service/efs" "github.com/aws/aws-sdk-go/service/servicequotas" - "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/smithy-go" cfn_iam "github.com/awslabs/goformation/v4/cloudformation/iam" . "github.com/onsi/ginkgo/v2" @@ -736,11 +736,11 @@ func GetPolicyArn(ctx context.Context, cfg awsv2.Config, name string) string { return "" } -func logAccountDetails(prov client.ConfigProvider) { +func logAccountDetails(cfg *awsv2.Config) { By("Getting AWS account details") - stsSvc := sts.New(prov) + stsSvc := sts.NewFromConfig(*cfg) - output, err := stsSvc.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + output, err := stsSvc.GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}) if err != nil { fmt.Fprintf(GinkgoWriter, "Couldn't get sts caller identity: err=%s\n", err) return diff --git a/test/e2e/shared/suite.go b/test/e2e/shared/suite.go index 027bfe3679..6d9f0bbe71 100644 --- a/test/e2e/shared/suite.go +++ b/test/e2e/shared/suite.go @@ -131,7 +131,7 @@ func Node1BeforeSuite(e2eCtx *E2EContext) []byte { e2eCtx.AWSSession = NewAWSSession() e2eCtx.AWSSessionV2 = NewAWSSessionV2() - logAccountDetails(e2eCtx.AWSSession) + logAccountDetails(e2eCtx.AWSSessionV2) bootstrapTemplate := getBootstrapTemplate(e2eCtx) bootstrapTags := map[string]string{"capa-e2e-test": "true"}