diff --git a/cmd/aws-iam-authenticator/root.go b/cmd/aws-iam-authenticator/root.go index bad3e2d1e..77b33f509 100644 --- a/cmd/aws-iam-authenticator/root.go +++ b/cmd/aws-iam-authenticator/root.go @@ -91,6 +91,7 @@ func getConfig() (config.Config, error) { PartitionID: viper.GetString("server.partition"), ClusterID: viper.GetString("clusterID"), ServerEC2DescribeInstancesRoleARN: viper.GetString("server.ec2DescribeInstancesRoleARN"), + SourceARN: viper.GetString("server.sourceARN"), HostPort: viper.GetInt("server.port"), Hostname: viper.GetString("server.hostname"), GenerateKubeconfigPath: viper.GetString("server.generateKubeconfig"), diff --git a/pkg/config/types.go b/pkg/config/types.go index 0b5e57486..5d2a9818e 100644 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -128,6 +128,12 @@ type Config struct { // running. ServerEC2DescribeInstancesRoleARN string + // SourceARN is value which is passed while assuming role specified by ServerEC2DescribeInstancesRoleARN. + // When a service assumes a role in your account, you can include the aws:SourceAccount and aws:SourceArn global + // condition context keys in your role trust policy to limit access to the role to only requests that are generated + // by expected resources. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + SourceARN string + // Address defines the hostname or IP Address to bind the HTTPS server to listen to. This is useful when creating // a local server to handle the authentication request for development. Address string diff --git a/pkg/ec2provider/ec2provider.go b/pkg/ec2provider/ec2provider.go index d661fe3e4..8dbd66265 100644 --- a/pkg/ec2provider/ec2provider.go +++ b/pkg/ec2provider/ec2provider.go @@ -7,6 +7,7 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/endpoints" @@ -35,6 +36,11 @@ const ( // Maximum time in Milliseconds to wait for a new batch call this also depends on if the instance size has // already become 100 then it will not respect this limit maxWaitIntervalForBatch = 200 + + // Headers for STS request for source ARN + headerSourceArn = "x-amz-source-arn" + // Headers for STS request for source account + headerSourceAccount = "x-amz-source-account" ) // Get a node name from instance ID @@ -60,7 +66,7 @@ type ec2ProviderImpl struct { instanceIdsChannel chan string } -func New(roleARN, region string, qps int, burst int) EC2Provider { +func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider { dnsCache := ec2PrivateDNSCache{ cache: make(map[string]string), lock: sync.RWMutex{}, @@ -70,7 +76,7 @@ func New(roleARN, region string, qps int, burst int) EC2Provider { lock: sync.RWMutex{}, } return &ec2ProviderImpl{ - ec2: ec2.New(newSession(roleARN, region, qps, burst)), + ec2: ec2.New(newSession(roleARN, sourceARN, region, qps, burst)), privateDNSCache: dnsCache, ec2Requests: ec2Requests, instanceIdsChannel: make(chan string, maxChannelSize), @@ -81,7 +87,7 @@ func New(roleARN, region string, qps int, burst int) EC2Provider { // the environment, shared credentials (~/.aws/credentials), or EC2 Instance // Role. -func newSession(roleARN, region string, qps int, burst int) *session.Session { +func newSession(roleARN, sourceARN, region string, qps int, burst int) *session.Session { sess := session.Must(session.NewSession()) sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ Name: "authenticatorUserAgent", @@ -103,8 +109,10 @@ func newSession(roleARN, region string, qps int, burst int) *session.Session { logrus.Errorf("Getting error = %s while creating rate limited client ", err) } + stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), + roleARN, sourceARN) ap := &stscreds.AssumeRoleProvider{ - Client: sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), + Client: stsClient, RoleARN: roleARN, Duration: time.Duration(60) * time.Minute, } @@ -277,3 +285,41 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string p.unsetRequestInFlightForInstanceId(id) } } + +func applySTSRequestHeaders(stsClient *sts.STS, roleARN, sourceARN string) *sts.STS { + logrus.Infof("Using AWS assumed role %v", roleARN) + sourceAcct, err := getSourceAccount(roleARN) + if err != nil { + logrus.Errorf("failed to parse source account from role ARN %v: %v", roleARN, err) + return stsClient + } + reqHeaders := map[string]string{ + headerSourceAccount: sourceAcct, + } + if sourceARN != "" { + reqHeaders[headerSourceArn] = sourceARN + } + stsClient.Handlers.Sign.PushFront(func(s *request.Request) { + s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) + }) + logrus.Infof("configuring STS client with extra headers, %v", reqHeaders) + return stsClient +} + +// getSourceAccount constructs source acct and return them for use +func getSourceAccount(roleARN string) (string, error) { + // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) + // arn:partition:service:region:account-id:resource-type/resource-id + // IAM format, region is always blank + // arn:aws:iam::account:role/role-name-with-path + if !arn.IsARN(roleARN) { + return "", fmt.Errorf("incorrect ARN format for role %s", roleARN) + } + + parsedArn, err := arn.Parse(roleARN) + if err != nil { + return "", err + } + + return parsedArn.AccountID, nil +} diff --git a/pkg/ec2provider/ec2provider_test.go b/pkg/ec2provider/ec2provider_test.go index 35a984697..912d73c8c 100644 --- a/pkg/ec2provider/ec2provider_test.go +++ b/pkg/ec2provider/ec2provider_test.go @@ -150,3 +150,44 @@ func prepare100InstanceOutput() []*ec2.Reservation { return reservations } + +func TestGetSourceAcctAndArn(t *testing.T) { + type args struct { + roleARN string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "corect role arn", + args: args{ + roleARN: "arn:aws:iam::123456789876:role/test-cluster", + }, + want: "123456789876", + wantErr: false, + }, + { + name: "incorect role arn", + args: args{ + roleARN: "arn:aws:iam::123456789876", + }, + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getSourceAccount(tt.args.roleARN) + if (err != nil) != tt.wantErr { + t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index b46007e01..3057cb8a0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -199,7 +199,7 @@ func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2 h := &handler{ verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion), - ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst), + ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst), clusterID: c.ClusterID, backendMapper: backendMapper, scrubbedAccounts: c.Config.ScrubbedAWSAccounts,