Skip to content

Commit

Permalink
Added accountid to AWSClient, set it during initialization phase, and…
Browse files Browse the repository at this point in the history
… use it for ARN building

We also now use sts:GetCallerIdentity in the GetAccountId func
  • Loading branch information
bigkraig committed Apr 28, 2016
1 parent 30acaf3 commit 5e122e5
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 65 deletions.
17 changes: 13 additions & 4 deletions builtin/providers/aws/auth_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-cleanhttp"
)

func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
func GetAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
// If we have creds from instance profile, we can use metadata API
if authProviderName == ec2rolecreds.ProviderName {
log.Println("[DEBUG] Trying to get account ID via AWS Metadata API")
Expand All @@ -42,16 +43,24 @@ func GetAccountId(iamconn *iam.IAM, authProviderName string) (string, error) {
return parseAccountIdFromArn(*outUser.User.Arn)
}

// Then try IAM ListRoles
awsErr, ok := err.(awserr.Error)
// AccessDenied and ValidationError can be raised
// if credentials belong to federated profile, so we ignore these
if !ok || (awsErr.Code() != "AccessDenied" && awsErr.Code() != "ValidationError") {
return "", fmt.Errorf("Failed getting account ID via 'iam:GetUser': %s", err)
}

log.Printf("[DEBUG] Getting account ID via iam:GetUser failed: %s", err)
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles instead")

// Then try STS GetCallerIdentity
log.Println("[DEBUG] Trying to get account ID via sts:GetCallerIdentity")
outCallerIdentity, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return *outCallerIdentity.Account, nil
}
log.Printf("[DEBUG] Getting account ID via sts:GetCallerIdentity failed: %s", err)

// Then try IAM ListRoles
log.Println("[DEBUG] Trying to get account ID via iam:ListRoles")
outRoles, err := iamconn.ListRoles(&iam.ListRolesInput{
MaxItems: aws.Int64(int64(1)),
})
Expand Down
35 changes: 18 additions & 17 deletions builtin/providers/aws/auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
)

func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
Expand All @@ -28,10 +29,10 @@ func TestAWSGetAccountId_shouldBeValid_fromEC2Role(t *testing.T) {
defer awsTs()

iamEndpoints := []*iamEndpoint{}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
Expand All @@ -55,10 +56,10 @@ func TestAWSGetAccountId_shouldBeValid_EC2RoleHasPriority(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, ec2rolecreds.ProviderName)
id, err := GetAccountId(iamConn, stsConn, ec2rolecreds.ProviderName)
if err != nil {
t.Fatalf("Getting account ID from EC2 metadata API failed: %s", err)
}
Expand All @@ -76,10 +77,10 @@ func TestAWSGetAccountId_shouldBeValid_fromIamUser(t *testing.T) {
Response: &iamResponse{200, iamResponse_GetUser_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via GetUser failed: %s", err)
}
Expand All @@ -101,10 +102,10 @@ func TestAWSGetAccountId_shouldBeValid_fromIamListRoles(t *testing.T) {
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
Expand All @@ -126,10 +127,10 @@ func TestAWSGetAccountId_shouldBeValid_federatedRole(t *testing.T) {
Response: &iamResponse{200, iamResponse_ListRoles_valid, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err != nil {
t.Fatalf("Getting account ID via ListRoles failed: %s", err)
}
Expand All @@ -151,10 +152,10 @@ func TestAWSGetAccountId_shouldError_unauthorizedFromIam(t *testing.T) {
Response: &iamResponse{403, iamResponse_ListRoles_unauthorized, "text/xml"},
},
}
ts, iamConn := getMockedAwsIamApi(iamEndpoints)
ts, iamConn, stsConn := getMockedAwsIamStsApi(iamEndpoints)
defer ts()

id, err := GetAccountId(iamConn, "")
id, err := GetAccountId(iamConn, stsConn, "")
if err == nil {
t.Fatal("Expected error when getting account ID")
}
Expand Down Expand Up @@ -586,9 +587,9 @@ func invalidAwsEnv(t *testing.T) func() {
return ts.Close
}

// getMockedAwsIamApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM server
func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
// getMockedAwsIamStsApi establishes a httptest server to simulate behaviour
// of a real AWS' IAM & STS server
func getMockedAwsIamStsApi(endpoints []*iamEndpoint) (func(), *iam.IAM, *sts.STS) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf := new(bytes.Buffer)
buf.ReadFrom(r.Body)
Expand Down Expand Up @@ -624,8 +625,8 @@ func getMockedAwsIamApi(endpoints []*iamEndpoint) (func(), *iam.IAM) {
CredentialsChainVerboseErrors: aws.Bool(true),
})
iamConn := iam.New(sess)

return ts.Close, iamConn
stsConn := sts.New(sess)
return ts.Close, iamConn, stsConn
}

func getEnv() *currentEnv {
Expand Down
20 changes: 11 additions & 9 deletions builtin/providers/aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ type AWSClient struct {
stsconn *sts.STS
redshiftconn *redshift.Redshift
r53conn *route53.Route53
accountid string
region string
rdsconn *rds.RDS
iamconn *iam.IAM
Expand Down Expand Up @@ -220,10 +221,11 @@ func (c *Config) Client() (interface{}, error) {
log.Println("[INFO] Initializing Elastic Beanstalk Connection")
client.elasticbeanstalkconn = elasticbeanstalk.New(sess)

authErr := c.ValidateAccountId(client.iamconn, cp.ProviderName)
account_id, authErr := c.ValidateAccountId(client.iamconn, client.stsconn, cp.ProviderName)
if authErr != nil {
errs = append(errs, authErr)
}
client.accountid = account_id

log.Println("[INFO] Initializing Kinesis Firehose Connection")
client.firehoseconn = firehose.New(sess)
Expand Down Expand Up @@ -343,35 +345,35 @@ func (c *Config) ValidateCredentials(iamconn *iam.IAM) error {

// ValidateAccountId returns a context-specific error if the configured account
// id is explicitly forbidden or not authorised; and nil if it is authorised.
func (c *Config) ValidateAccountId(iamconn *iam.IAM, authProviderName string) error {
func (c *Config) ValidateAccountId(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, error) {
if c.AllowedAccountIds == nil && c.ForbiddenAccountIds == nil {
return nil
return "", nil
}

log.Printf("[INFO] Validating account ID")
account_id, err := GetAccountId(iamconn, authProviderName)
account_id, err := GetAccountId(iamconn, stsconn, authProviderName)
if err != nil {
return err
return "", err
}

if c.ForbiddenAccountIds != nil {
for _, id := range c.ForbiddenAccountIds {
if id == account_id {
return fmt.Errorf("Forbidden account ID (%s)", id)
return "", fmt.Errorf("Forbidden account ID (%s)", id)
}
}
}

if c.AllowedAccountIds != nil {
for _, id := range c.AllowedAccountIds {
if id == account_id {
return nil
return account_id, nil
}
}
return fmt.Errorf("Account ID not allowed (%s)", account_id)
return "", fmt.Errorf("Account ID not allowed (%s)", account_id)
}

return nil
return account_id, nil
}

// addTerraformVersionToUserAgent is a named handler that will add Terraform's
Expand Down
8 changes: 1 addition & 7 deletions builtin/providers/aws/resource_aws_db_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/sts"

"github.com/hashicorp/terraform/helper/resource"
"github.com/hashicorp/terraform/helper/schema"
Expand Down Expand Up @@ -973,14 +972,9 @@ func resourceAwsDbInstanceStateRefreshFunc(
}

func buildRDSARN(identifier string, meta interface{}) (string, error) {
stsconn := meta.(*AWSClient).stsconn
region := meta.(*AWSClient).region
accountID := meta.(*AWSClient).accountid

resp, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", err
}
accountID := *resp.Account
arn := fmt.Sprintf("arn:aws:rds:%s:%s:db:%s", region, accountID, identifier)
return arn, nil
}
8 changes: 1 addition & 7 deletions builtin/providers/aws/resource_aws_db_parameter_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/sts"
)

func resourceAwsDbParameterGroup() *schema.Resource {
Expand Down Expand Up @@ -272,14 +271,9 @@ func resourceAwsDbParameterHash(v interface{}) int {
}

func buildRDSPGARN(d *schema.ResourceData, meta interface{}) (string, error) {
stsconn := meta.(*AWSClient).stsconn
region := meta.(*AWSClient).region
accountID := meta.(*AWSClient).accountid

resp, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", err
}
accountID := *resp.Account
arn := fmt.Sprintf("arn:aws:rds:%s:%s:pg:%s", region, accountID, d.Id())
return arn, nil
}
8 changes: 1 addition & 7 deletions builtin/providers/aws/resource_aws_db_security_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/terraform/helper/hashcode"
"github.com/hashicorp/terraform/helper/resource"
Expand Down Expand Up @@ -344,14 +343,9 @@ func resourceAwsDbSecurityGroupStateRefreshFunc(
}

func buildRDSSecurityGroupARN(d *schema.ResourceData, meta interface{}) (string, error) {
stsconn := meta.(*AWSClient).stsconn
region := meta.(*AWSClient).region
accountID := meta.(*AWSClient).accountid

resp, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", err
}
accountID := *resp.Account
arn := fmt.Sprintf("arn:aws:rds:%s:%s:secgrp:%s", region, accountID, d.Id())
return arn, nil
}
8 changes: 1 addition & 7 deletions builtin/providers/aws/resource_aws_db_subnet_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/terraform/helper/resource"
"github.com/hashicorp/terraform/helper/schema"
)
Expand Down Expand Up @@ -225,14 +224,9 @@ func resourceAwsDbSubnetGroupDeleteRefreshFunc(
}

func buildRDSsubgrpARN(d *schema.ResourceData, meta interface{}) (string, error) {
stsconn := meta.(*AWSClient).stsconn
region := meta.(*AWSClient).region
accountID := meta.(*AWSClient).accountid

resp, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", err
}
accountID := *resp.Account
arn := fmt.Sprintf("arn:aws:rds:%s:%s:subgrp:%s", region, accountID, d.Id())
return arn, nil
}
Expand Down
8 changes: 1 addition & 7 deletions builtin/providers/aws/resource_aws_elasticache_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/terraform/helper/resource"
"github.com/hashicorp/terraform/helper/schema"
)
Expand Down Expand Up @@ -620,14 +619,9 @@ func cacheClusterStateRefreshFunc(conn *elasticache.ElastiCache, clusterID, give
}

func buildECARN(d *schema.ResourceData, meta interface{}) (string, error) {
stsconn := meta.(*AWSClient).stsconn
region := meta.(*AWSClient).region
accountID := meta.(*AWSClient).accountid

resp, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", err
}
accountID := *resp.Account
arn := fmt.Sprintf("arn:aws:elasticache:%s:%s:cluster:%s", region, accountID, d.Id())
return arn, nil
}

0 comments on commit 5e122e5

Please sign in to comment.