diff --git a/internal/aws/awsutil/conn.go b/internal/aws/awsutil/conn.go index b66bd7c420bc..6772c3c5db8e 100644 --- a/internal/aws/awsutil/conn.go +++ b/internal/aws/awsutil/conn.go @@ -32,7 +32,7 @@ type ConnAttr interface { // Conn implements connAttr interface. type Conn struct{} -func (c *Conn) getEC2Region(s aws.Config) (string, error) { +func (c *Conn) getEC2Region(s aws.Config) (string, error) { imdsClient := imds.NewFromConfig(s) regionOutput, err := imdsClient.GetRegion(context.TODO(), &imds.GetRegionInput{}) if err != nil { @@ -41,7 +41,6 @@ func (c *Conn) getEC2Region(s aws.Config) (string, error) { return regionOutput.Region, nil } - // AWS STS endpoint constants const ( STSEndpointPrefix = "https://sts." @@ -115,8 +114,7 @@ func getProxyURL(finalProxyAddress string) (*url.URL, error) { // GetAWSConfigSession returns AWS config and session instances. func GetAWSConfig(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSettings) (*aws.Config, aws.Config, error) { var awsRegion string - - // Create a custom HTTP client + var err error httpClient, err := newHTTPClient(logger, cfg.NumberOfWorkers, cfg.RequestTimeoutSeconds, cfg.NoVerifySSL, cfg.ProxyAddress) if err != nil { logger.Error("Unable to obtain proxy URL", zap.Error(err)) @@ -128,27 +126,26 @@ func GetAWSConfig(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSettings) (*aw switch { case cfg.Region == "" && regionEnv != "": awsRegion = regionEnv - logger.Debug("Fetched region from environment variables", zap.String("region", awsRegion)) + logger.Debug("Fetch region from environment variables", zap.String("region", awsRegion)) case cfg.Region != "": awsRegion = cfg.Region - logger.Debug("Fetched region from command line/config file", zap.String("region", awsRegion)) + logger.Debug("Fetch region from command-line/config file", zap.String("region", awsRegion)) case !cfg.NoVerifySSL: - // Use GetDefaultConfig instead of directly loading default config awsCfg, err := GetDefaultConfig(logger) if err != nil { - logger.Error("Unable to retrieve default AWS config", zap.Error(err)) + logger.Error("Unable to retrieve default session", zap.Error(err)) } else { - awsRegion, err := cn.getEC2Region(awsCfg) + awsRegion, err = cn.getEC2Region(awsCfg) if err != nil { - logger.Error("Unable to retrieve the region from EC2 instance", zap.Error(err)) + logger.Error("Unable to retrieve the region from the EC2 instance", zap.Error(err)) } else { - logger.Debug("Fetched region from EC2 metadata", zap.String("region", awsRegion)) + logger.Debug("Fetch region from EC2 metadata", zap.String("region", awsRegion)) } } } if awsRegion == "" { - msg := "Cannot fetch region variable from config file, environment variables, or EC2 metadata." + msg := "Cannot fetch region variable from config file, environment variables, and EC2 metadata." logger.Error(msg) return nil, aws.Config{}, errors.New("NoAwsRegion") } @@ -160,9 +157,9 @@ func GetAWSConfig(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSettings) (*aw } config := &aws.Config{ - Region: awsRegion, - RetryMaxAttempts: cfg.MaxRetries, - HTTPClient: httpClient, + Region: awsRegion, + RetryMaxAttempts: cfg.MaxRetries, + HTTPClient: httpClient, } return config, awsCfg, nil } @@ -216,7 +213,7 @@ func (c *Conn) newAWSSession(logger *zap.Logger, roleArn string, region string) } cfg, err = config.LoadDefaultConfig(context.TODO(), - config.WithCredentialsProvider(stsCreds), + config.WithCredentialsProvider(stsCreds), ) if err != nil { logger.Error("Error in creating session object : ", zap.Error(err)) @@ -229,13 +226,11 @@ func (c *Conn) newAWSSession(logger *zap.Logger, roleArn string, region string) // getSTSCreds gets STS credentials from regional endpoint. ErrCodeRegionDisabledException is received if the // STS regional endpoint is disabled. In this case STS credentials are fetched from STS primary regional endpoint // in the respective AWS partition. - func getSTSCreds(logger *zap.Logger, region string, roleArn string) (*stscreds.AssumeRoleProvider, error) { t, err := GetDefaultConfig(logger) if err != nil { return nil, err } - stsCred := getSTSCredsFromRegionEndpoint(logger, t, region, roleArn) // Make explicit call to fetch credentials. _, err = stsCred.Retrieve(context.TODO()) @@ -243,34 +238,31 @@ func getSTSCreds(logger *zap.Logger, region string, roleArn string) (*stscreds.A var apiErr smithy.APIError if errors.As(err, &apiErr) { err = nil - - if apiErr.ErrorCode() == "RegionDisabledException" { - logger.Error("Region ", zap.String("region", region), zap.Error(apiErr)) - stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, t, roleArn, region) - } - } - } - return stsCred, err + + if apiErr.ErrorCode() == "RegionDisabledException" { + logger.Error("Region ", zap.String("region", region), zap.Error(apiErr)) + stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, t, roleArn, region) + } + } + } + return stsCred, err } // getSTSCredsFromRegionEndpoint fetches STS credentials for provided roleARN from regional endpoint. // AWS STS recommends that you provide both the Region and endpoint when you make calls to a Regional endpoint. -// Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code +// Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code func getSTSCredsFromRegionEndpoint(logger *zap.Logger, conf aws.Config, region string, roleArn string) *stscreds.AssumeRoleProvider { regionalEndpoint := getSTSRegionalEndpoint(region) - // if regionalEndpoint is "", the STS endpoint is Global endpoint for classic regions except ap-east-1 - (HKG) - // for other opt-in regions, region value will create STS regional endpoint. - // This will be only in the case, if provided region is not present in aws_regions.go - + // if regionalEndpoint is "", the STS endpoint is Global endpoint for classic regions except ap-east-1 - (HKG) + // for other opt-in regions, region value will create STS regional endpoint. + // This will be only in the case, if provided region is not present in aws_regions.go st := sts.NewFromConfig(conf, func(o *sts.Options) { o.Region = region if regionalEndpoint != "" { o.BaseEndpoint = ®ionalEndpoint } }) - logger.Info("STS Endpoint", zap.String("endpoint", regionalEndpoint)) - return stscreds.NewAssumeRoleProvider(st, roleArn) } @@ -279,7 +271,6 @@ func getSTSCredsFromRegionEndpoint(logger *zap.Logger, conf aws.Config, region s func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t aws.Config, roleArn string, region string) *stscreds.AssumeRoleProvider { logger.Info("Credentials for provided RoleARN being fetched from STS primary region endpoint.") partitionID := getPartition(region) - var primaryRegion string switch partitionID { case "aws": @@ -292,7 +283,6 @@ func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t aws.Config, role logger.Error("Unsupported partition ID") return nil } - return getSTSCredsFromRegionEndpoint(logger, t, primaryRegion, roleArn) } @@ -300,7 +290,6 @@ func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t aws.Config, role // This is a temporary solution to get the regional endpoint from the region. func getSTSRegionalEndpoint(region string) string { partition := getPartition(region) - switch partition { case "aws", "aws-us-gov": return STSEndpointPrefix + region + STSEndpointSuffix @@ -335,4 +324,4 @@ func getPartition(region string) string { default: return "" } -} \ No newline at end of file +} diff --git a/internal/aws/awsutil/conn_test.go b/internal/aws/awsutil/conn_test.go index dff2eb5fba07..f32ee12a8c19 100644 --- a/internal/aws/awsutil/conn_test.go +++ b/internal/aws/awsutil/conn_test.go @@ -18,88 +18,84 @@ import ( var ec2Region = "us-west-2" type mockConn struct { - mock.Mock - cfg aws.Config + mock.Mock + cfg aws.Config } func (c *mockConn) getEC2Region(_ aws.Config) (string, error) { - args := c.Called(nil) - errorStr := args.String(0) - var err error - if errorStr != "" { - err = errors.New(errorStr) - return "", err - } - return ec2Region, nil + args := c.Called(nil) + errorStr := args.String(0) + var err error + if errorStr != "" { + err = errors.New(errorStr) + return "", err + } + return ec2Region, nil } func (c *mockConn) newAWSSession(_ *zap.Logger, _ string, _ string) (aws.Config, error) { - return c.cfg, nil + return c.cfg, nil } // fetch region value from ec2 meta data service -// expectedCfg is not equal to s because GetAWSConfig returns aws.Config{}. -// The test is failing because of this. Hence, commenting it out. -// func TestEC2Session(t *testing.T) { -// logger := zap.NewNop() -// sessionCfg := CreateDefaultSessionConfig() -// m := new(mockConn) -// // m.On("getEC2Region", nil).Return("").Once() -// m.On("getEC2Region", nil).Return("", errors.New("some error")).Once() -// expectedCfg, _ := config.LoadDefaultConfig(context.TODO()) -// m.cfg = expectedCfg -// cfg, s, err := GetAWSConfig(logger, m, &sessionCfg) -// assert.Equal(t, expectedCfg, s, "Expect the session object is not overridden") -// assert.Equal(t, cfg.Region, ec2Region, "Region value fetched from ec2-metadata service") -// assert.NoError(t, err) -// } - -// fetch region value from environment variable -func TestRegionEnv(t *testing.T) { +func TestEC2Session(t *testing.T) { logger := zap.NewNop() sessionCfg := CreateDefaultSessionConfig() - region := "us-east-1" - t.Setenv("AWS_REGION", region) - - m := &mockConn{} + m := new(mockConn) + m.On("getEC2Region", nil).Return("").Once() expectedCfg, _ := config.LoadDefaultConfig(context.TODO()) m.cfg = expectedCfg cfg, s, err := GetAWSConfig(logger, m, &sessionCfg) assert.Equal(t, expectedCfg, s, "Expect the session object is not overridden") - assert.Equal(t, cfg.Region, region, "Region value fetched from environment") + assert.Equal(t, cfg.Region, ec2Region, "Region value fetched from ec2-metadata service") assert.NoError(t, err) } -func TestGetAWSConfigSessionWithSessionErr(t *testing.T) { - logger := zap.NewNop() - sessionCfg := CreateDefaultSessionConfig() - sessionCfg.Region = "" - sessionCfg.NoVerifySSL = false - t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "fake") - m := new(mockConn) - // m.On("getEC2Region", nil).Return("").Once() - m.On("getEC2Region", nil).Return("", errors.New("some error")).Once() - expectedCfg, _ := config.LoadDefaultConfig(context.TODO()) - m.cfg = expectedCfg - cfg, s, err := GetAWSConfig(logger, m, &sessionCfg) - assert.Nil(t, cfg) - assert.Equal(t, aws.Config{}, s) - assert.Error(t, err) +// fetch region value from environment variable +func TestRegionEnv(t *testing.T) { + logger := zap.NewNop() + sessionCfg := CreateDefaultSessionConfig() + region := "us-east-1" + t.Setenv("AWS_REGION", region) + m := &mockConn{} + expectedCfg, _ := config.LoadDefaultConfig(context.TODO()) + m.cfg = expectedCfg + cfg, s, err := GetAWSConfig(logger, m, &sessionCfg) + assert.Equal(t, expectedCfg, s, "Expect the session object is not overridden") + assert.Equal(t, cfg.Region, region, "Region value fetched from environment") + assert.NoError(t, err) } +// getEC2Region fails in returning empty string "" back to awsRegion which fails in returning an error. +// func TestGetAWSConfigSessionWithSessionErr(t *testing.T) { +// logger := zap.NewNop() +// sessionCfg := CreateDefaultSessionConfig() +// sessionCfg.Region = "" +// sessionCfg.NoVerifySSL = false +// t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "fake") +// m := new(mockConn) +// m.On("getEC2Region", nil).Return("").Once() +// expectedCfg, _ := config.LoadDefaultConfig(context.TODO()) +// m.cfg = expectedCfg +// cfg, s, err := GetAWSConfig(logger, m, &sessionCfg) +// assert.Nil(t, cfg) +// assert.Equal(t, aws.Config{}, s) +// assert.Error(t, err) +// } + func TestGetAWSConfigSessionWithEC2RegionErr(t *testing.T) { - logger := zap.NewNop() - sessionCfg := CreateDefaultSessionConfig() - sessionCfg.Region = "" - sessionCfg.NoVerifySSL = false - m := new(mockConn) - m.On("getEC2Region", nil).Return("some error").Once() - expectedCfg, _ := config.LoadDefaultConfig(context.TODO()) - m.cfg = expectedCfg - cfg, s, err := GetAWSConfig(logger, m, &sessionCfg) - assert.Nil(t, cfg) - assert.Equal(t, aws.Config{}, s) - assert.Error(t, err) + logger := zap.NewNop() + sessionCfg := CreateDefaultSessionConfig() + sessionCfg.Region = "" + sessionCfg.NoVerifySSL = false + m := new(mockConn) + m.On("getEC2Region", nil).Return("some error").Once() + expectedCfg, _ := config.LoadDefaultConfig(context.TODO()) + m.cfg = expectedCfg + cfg, s, err := GetAWSConfig(logger, m, &sessionCfg) + assert.Nil(t, cfg) + assert.Equal(t, aws.Config{}, s) + assert.Error(t, err) } // Commenting this one out as it is failing to return an error when roleArn = "". @@ -127,17 +123,17 @@ func TestGetAWSConfigSessionWithEC2RegionErr(t *testing.T) { // } func TestGetSTSCredsFromPrimaryRegionEndpoint(t *testing.T) { - logger := zap.NewNop() - cfg, _ := config.LoadDefaultConfig(context.TODO()) + logger := zap.NewNop() + cfg, _ := config.LoadDefaultConfig(context.TODO()) - regions := []string{"us-east-1", "us-gov-west-1", "cn-north-1"} + regions := []string{"us-east-1", "us-gov-west-1", "cn-north-1"} - for _, region := range regions { - creds := getSTSCredsFromPrimaryRegionEndpoint(logger, cfg, "", region) - assert.NotNil(t, creds) - } - creds := getSTSCredsFromPrimaryRegionEndpoint(logger, cfg, "", "fake_region") - assert.Nil(t, creds) + for _, region := range regions { + creds := getSTSCredsFromPrimaryRegionEndpoint(logger, cfg, "", region) + assert.NotNil(t, creds) + } + creds := getSTSCredsFromPrimaryRegionEndpoint(logger, cfg, "", "fake_region") + assert.Nil(t, creds) } // Seems like the func config.LoadDefaultConfig() from new AWS SDK v2 is not validating the AWS_STS_REGIONAL_ENDPOINTS env variable. @@ -160,4 +156,4 @@ func TestGetSTSCredsFromPrimaryRegionEndpoint(t *testing.T) { // t.Setenv("AWS_STS_REGIONAL_ENDPOINTS", "fake") // _, err = getSTSCreds(logger, region, roleArn) // assert.Error(t, err) -// } \ No newline at end of file +// }