Skip to content

Commit

Permalink
refactor : Migrated from AWS SDK v1 to v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Avinash-Acharya committed Feb 13, 2025
1 parent af6b17e commit a7840c0
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 232 deletions.
236 changes: 135 additions & 101 deletions internal/aws/awsutil/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,43 @@
package awsutil // import "github.com/open-telemetry/opentelemetry-collector-contrib/internal/aws/awsutil"

import (
"context"
"crypto/tls"
"errors"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/aws/smithy-go"
"go.uber.org/zap"
"golang.org/x/net/http2"
)

type ConnAttr interface {
newAWSSession(logger *zap.Logger, roleArn string, region string) (*session.Session, error)
getEC2Region(s *session.Session) (string, error)
newAWSSession(logger *zap.Logger, roleArn string, region string) (aws.Config, error)
getEC2Region(c aws.Config) (string, error)
}

// Conn implements connAttr interface.
type Conn struct{}

func (c *Conn) getEC2Region(s *session.Session) (string, error) {
return ec2metadata.New(s).Region()
func (c *Conn) getEC2Region(s aws.Config) (string, error) {
imdsClient := imds.NewFromConfig(s)
regionOutput, err := imdsClient.GetRegion(context.TODO(), &imds.GetRegionInput{})
if err != nil {
return "", err
}
return regionOutput.Region, nil
}


// AWS STS endpoint constants
const (
STSEndpointPrefix = "https://sts."
Expand Down Expand Up @@ -107,57 +113,58 @@ func getProxyURL(finalProxyAddress string) (*url.URL, error) {
}

// GetAWSConfigSession returns AWS config and session instances.
func GetAWSConfigSession(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSettings) (*aws.Config, *session.Session, error) {
var s *session.Session
var err error
func GetAWSConfig(logger *zap.Logger, cn ConnAttr, cfg *AWSSessionSettings) (*aws.Config, aws.Config, error) {
var awsRegion string
http, err := newHTTPClient(logger, cfg.NumberOfWorkers, cfg.RequestTimeoutSeconds, cfg.NoVerifySSL, cfg.ProxyAddress)

// Create a custom HTTP client
httpClient, err := newHTTPClient(logger, cfg.NumberOfWorkers, cfg.RequestTimeoutSeconds, cfg.NoVerifySSL, cfg.ProxyAddress)
if err != nil {
logger.Error("unable to obtain proxy URL", zap.Error(err))
return nil, nil, err
logger.Error("Unable to obtain proxy URL", zap.Error(err))
return nil, aws.Config{}, err
}

regionEnv := os.Getenv("AWS_REGION")

switch {
case cfg.Region == "" && regionEnv != "":
awsRegion = regionEnv
logger.Debug("Fetch region from environment variables", zap.String("region", awsRegion))
logger.Debug("Fetched region from environment variables", zap.String("region", awsRegion))
case cfg.Region != "":
awsRegion = cfg.Region
logger.Debug("Fetch region from commandline/config file", zap.String("region", awsRegion))
logger.Debug("Fetched region from command line/config file", zap.String("region", awsRegion))
case !cfg.NoVerifySSL:
var es *session.Session
es, err = GetDefaultSession(logger)
// Use GetDefaultConfig instead of directly loading default config
awsCfg, err := GetDefaultConfig(logger)
if err != nil {
logger.Error("Unable to retrieve default session", zap.Error(err))
logger.Error("Unable to retrieve default AWS config", zap.Error(err))
} else {
awsRegion, err = cn.getEC2Region(es)
awsRegion, err := cn.getEC2Region(awsCfg)
if err != nil {
logger.Error("Unable to retrieve the region from the EC2 instance", zap.Error(err))
logger.Error("Unable to retrieve the region from EC2 instance", zap.Error(err))
} else {
logger.Debug("Fetch region from ec2 metadata", zap.String("region", awsRegion))
logger.Debug("Fetched region from EC2 metadata", zap.String("region", awsRegion))
}
}
}

if awsRegion == "" {
msg := "Cannot fetch region variable from config file, environment variables and ec2 metadata."
msg := "Cannot fetch region variable from config file, environment variables, or EC2 metadata."
logger.Error(msg)
return nil, nil, awserr.New("NoAwsRegion", msg, nil)
return nil, aws.Config{}, errors.New("NoAwsRegion")
}
s, err = cn.newAWSSession(logger, cfg.RoleARN, awsRegion)

awsCfg, err := cn.newAWSSession(logger, cfg.RoleARN, awsRegion)
if err != nil {
return nil, nil, err
logger.Error("Failed to create AWS session", zap.Error(err))
return nil, aws.Config{}, err
}

config := &aws.Config{
Region: aws.String(awsRegion),
DisableParamValidation: aws.Bool(true),
MaxRetries: aws.Int(cfg.MaxRetries),
Endpoint: aws.String(cfg.Endpoint),
HTTPClient: http,
Region: awsRegion,
RetryMaxAttempts: cfg.MaxRetries,
HTTPClient: httpClient,
}
return config, s, nil
return config, awsCfg, nil
}

// ProxyServerTransport configures HTTP transport for TCP Proxy Server.
Expand Down Expand Up @@ -193,112 +200,139 @@ func ProxyServerTransport(logger *zap.Logger, config *AWSSessionSettings) (*http
return transport, nil
}

func (c *Conn) newAWSSession(logger *zap.Logger, roleArn string, region string) (*session.Session, error) {
var s *session.Session
func (c *Conn) newAWSSession(logger *zap.Logger, roleArn string, region string) (aws.Config, error) {
var cfg aws.Config
var err error
if roleArn == "" {
s, err = GetDefaultSession(logger)
cfg, err = GetDefaultConfig(logger)
if err != nil {
return s, err
return aws.Config{}, err
}
} else {
stsCreds, _ := getSTSCreds(logger, region, roleArn)
stsCreds, err := getSTSCreds(logger, region, roleArn)
if err != nil {
logger.Error("Error in getting STS credentials: ", zap.Error(err))
return aws.Config{}, err
}

s, err = session.NewSession(&aws.Config{
Credentials: stsCreds,
})
cfg, err = config.LoadDefaultConfig(context.TODO(),
config.WithCredentialsProvider(stsCreds),
)
if err != nil {
logger.Error("Error in creating session object : ", zap.Error(err))
return s, err
return aws.Config{}, err
}
}
return s, nil
return cfg, nil
}

// getSTSCreds gets STS credentials from regional endpoint. ErrCodeRegionDisabledException is received if the
// STS regional endpoint is disabled. In this case STS credentials are fetched from STS primary regional endpoint
// in the respective AWS partition.
func getSTSCreds(logger *zap.Logger, region string, roleArn string) (*credentials.Credentials, error) {
t, err := GetDefaultSession(logger)

func getSTSCreds(logger *zap.Logger, region string, roleArn string) (*stscreds.AssumeRoleProvider, error) {
t, err := GetDefaultConfig(logger)
if err != nil {
return nil, err
}

stsCred := getSTSCredsFromRegionEndpoint(logger, t, region, roleArn)
// Make explicit call to fetch credentials.
_, err = stsCred.Get()
_, err = stsCred.Retrieve(context.TODO())
if err != nil {
var awsErr awserr.Error
if errors.As(err, &awsErr) {
var apiErr smithy.APIError
if errors.As(err, &apiErr) {
err = nil

if awsErr.Code() == sts.ErrCodeRegionDisabledException {
logger.Error("Region ", zap.String("region", region), zap.Error(awsErr))
stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, t, roleArn, region)
}
}
}
return stsCred, err
if apiErr.ErrorCode() == "RegionDisabledException" {
logger.Error("Region ", zap.String("region", region), zap.Error(apiErr))
stsCred = getSTSCredsFromPrimaryRegionEndpoint(logger, t, roleArn, region)
}
}
}
return stsCred, err
}

// getSTSCredsFromRegionEndpoint fetches STS credentials for provided roleARN from regional endpoint.
// AWS STS recommends that you provide both the Region and endpoint when you make calls to a Regional endpoint.
// Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
func getSTSCredsFromRegionEndpoint(logger *zap.Logger, sess *session.Session, region string,
roleArn string,
) *credentials.Credentials {
// Reference: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html#id_credentials_temp_enable-regions_writing_code
func getSTSCredsFromRegionEndpoint(logger *zap.Logger, conf aws.Config, region string, roleArn string) *stscreds.AssumeRoleProvider {
regionalEndpoint := getSTSRegionalEndpoint(region)
// if regionalEndpoint is "", the STS endpoint is Global endpoint for classic regions except ap-east-1 - (HKG)
// for other opt-in regions, region value will create STS regional endpoint.
// This will be only in the case, if provided region is not present in aws_regions.go
c := &aws.Config{Region: aws.String(region), Endpoint: &regionalEndpoint}
st := sts.New(sess, c)
logger.Info("STS Endpoint ", zap.String("endpoint", st.Endpoint))
return stscreds.NewCredentialsWithClient(st, roleArn)
// if regionalEndpoint is "", the STS endpoint is Global endpoint for classic regions except ap-east-1 - (HKG)
// for other opt-in regions, region value will create STS regional endpoint.
// This will be only in the case, if provided region is not present in aws_regions.go

st := sts.NewFromConfig(conf, func(o *sts.Options) {
o.Region = region
if regionalEndpoint != "" {
o.BaseEndpoint = &regionalEndpoint
}
})

logger.Info("STS Endpoint", zap.String("endpoint", regionalEndpoint))

return stscreds.NewAssumeRoleProvider(st, roleArn)
}

// getSTSCredsFromPrimaryRegionEndpoint fetches STS credentials for provided roleARN from primary region endpoint in
// the respective partition.
func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t *session.Session, roleArn string,
region string,
) *credentials.Credentials {
// TODO: Refactor this function once the Solution is found to provides a way to get the partition ID from the region.
// The partition ID is used to identify the AWS partition is a temporary solution to get the partition ID from the region.
func getSTSCredsFromPrimaryRegionEndpoint(logger *zap.Logger, t aws.Config, roleArn string, region string) *stscreds.AssumeRoleProvider {
logger.Info("Credentials for provided RoleARN being fetched from STS primary region endpoint.")
partitionID := getPartition(region)

var primaryRegion string
switch partitionID {
case endpoints.AwsPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsEast1RegionID, roleArn)
case endpoints.AwsCnPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.CnNorth1RegionID, roleArn)
case endpoints.AwsUsGovPartitionID:
return getSTSCredsFromRegionEndpoint(logger, t, endpoints.UsGovWest1RegionID, roleArn)
case "aws":
primaryRegion = "us-east-1"
case "aws-cn":
primaryRegion = "cn-north-1"
case "aws-us-gov":
primaryRegion = "us-gov-west-1"
default:
logger.Error("Unsupported partition ID")
return nil
}

return nil
return getSTSCredsFromRegionEndpoint(logger, t, primaryRegion, roleArn)
}

func getSTSRegionalEndpoint(r string) string {
p := getPartition(r)
// getSTSRegionalEndpoint returns the regional endpoint for the provided region.
// This is a temporary solution to get the regional endpoint from the region.
func getSTSRegionalEndpoint(region string) string {
partition := getPartition(region)

var e string
if p == endpoints.AwsPartitionID || p == endpoints.AwsUsGovPartitionID {
e = STSEndpointPrefix + r + STSEndpointSuffix
} else if p == endpoints.AwsCnPartitionID {
e = STSEndpointPrefix + r + STSAwsCnPartitionIDSuffix
switch partition {
case "aws", "aws-us-gov":
return STSEndpointPrefix + region + STSEndpointSuffix
case "aws-cn":
return STSEndpointPrefix + region + STSAwsCnPartitionIDSuffix
default:
return ""
}
return e
}

func GetDefaultSession(logger *zap.Logger) (*session.Session, error) {
result, serr := session.NewSession()
if serr != nil {
logger.Error("Error in creating session object ", zap.Error(serr))
return result, serr
func GetDefaultConfig(logger *zap.Logger) (aws.Config, error) {
cfg, err := config.LoadDefaultConfig(context.TODO())
if err != nil {
logger.Error("Error in creating session object ", zap.Error(err))
return aws.Config{}, err
}
return result, nil
return cfg, nil
}

// getPartition return AWS Partition for the provided region.
// Currently, `endpoints` from AWS SDK Go v2 docs does not provide a way to get the partition ID from the region.
// This function is a temporary solution to get the partition ID from the region.
func getPartition(region string) string {
p, _ := endpoints.PartitionForRegion(endpoints.DefaultPartitions(), region)
return p.ID()
}
switch {
case strings.HasPrefix(region, "cn-"):
return "aws-cn" // AWS China Partition
case strings.HasPrefix(region, "us-gov-"):
return "aws-us-gov" // AWS GovCloud Partition
case strings.HasPrefix(region, "us"):
return "aws" // AWS Standard Partition
case strings.HasPrefix(region, "aws"):
return "aws" // AWS Partition
default:
return ""
}
}
Loading

0 comments on commit a7840c0

Please sign in to comment.