Skip to content

Commit

Permalink
Add sourceArn to sts through headers
Browse files Browse the repository at this point in the history
  • Loading branch information
haoranleo committed Aug 26, 2024
1 parent 637fcb7 commit 026a356
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 5 deletions.
1 change: 1 addition & 0 deletions cmd/aws-iam-authenticator/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func getConfig() (config.Config, error) {
PartitionID: viper.GetString("server.partition"),
ClusterID: viper.GetString("clusterID"),
ServerEC2DescribeInstancesRoleARN: viper.GetString("server.ec2DescribeInstancesRoleARN"),
SourceARN: viper.GetString("server.sourceARN"),
HostPort: viper.GetInt("server.port"),
Hostname: viper.GetString("server.hostname"),
GenerateKubeconfigPath: viper.GetString("server.generateKubeconfig"),
Expand Down
6 changes: 6 additions & 0 deletions pkg/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ type Config struct {
// running.
ServerEC2DescribeInstancesRoleARN string

// SourceARN is value which is passed while assuming role specified by ServerEC2DescribeInstancesRoleARN.
// When a service assumes a role in your account, you can include the aws:SourceAccount and aws:SourceArn global
// condition context keys in your role trust policy to limit access to the role to only requests that are generated
// by expected resources. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
SourceARN string

// Address defines the hostname or IP Address to bind the HTTPS server to listen to. This is useful when creating
// a local server to handle the authentication request for development.
Address string
Expand Down
54 changes: 50 additions & 4 deletions pkg/ec2provider/ec2provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/endpoints"
Expand Down Expand Up @@ -35,6 +36,11 @@ const (
// Maximum time in Milliseconds to wait for a new batch call this also depends on if the instance size has
// already become 100 then it will not respect this limit
maxWaitIntervalForBatch = 200

// Headers for STS request for source ARN
headerSourceArn = "x-amz-source-arn"
// Headers for STS request for source account
headerSourceAccount = "x-amz-source-account"
)

// Get a node name from instance ID
Expand All @@ -60,7 +66,7 @@ type ec2ProviderImpl struct {
instanceIdsChannel chan string
}

func New(roleARN, region string, qps int, burst int) EC2Provider {
func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider {
dnsCache := ec2PrivateDNSCache{
cache: make(map[string]string),
lock: sync.RWMutex{},
Expand All @@ -70,7 +76,7 @@ func New(roleARN, region string, qps int, burst int) EC2Provider {
lock: sync.RWMutex{},
}
return &ec2ProviderImpl{
ec2: ec2.New(newSession(roleARN, region, qps, burst)),
ec2: ec2.New(newSession(roleARN, sourceARN, region, qps, burst)),
privateDNSCache: dnsCache,
ec2Requests: ec2Requests,
instanceIdsChannel: make(chan string, maxChannelSize),
Expand All @@ -81,7 +87,7 @@ func New(roleARN, region string, qps int, burst int) EC2Provider {
// the environment, shared credentials (~/.aws/credentials), or EC2 Instance
// Role.

func newSession(roleARN, region string, qps int, burst int) *session.Session {
func newSession(roleARN, sourceARN, region string, qps int, burst int) *session.Session {
sess := session.Must(session.NewSession())
sess.Handlers.Build.PushFrontNamed(request.NamedHandler{
Name: "authenticatorUserAgent",
Expand All @@ -103,8 +109,10 @@ func newSession(roleARN, region string, qps int, burst int) *session.Session {
logrus.Errorf("Getting error = %s while creating rate limited client ", err)
}

stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)),
roleARN, sourceARN)
ap := &stscreds.AssumeRoleProvider{
Client: sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)),
Client: stsClient,
RoleARN: roleARN,
Duration: time.Duration(60) * time.Minute,
}
Expand Down Expand Up @@ -277,3 +285,41 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string
p.unsetRequestInFlightForInstanceId(id)
}
}

func applySTSRequestHeaders(stsClient *sts.STS, roleARN, sourceARN string) *sts.STS {
logrus.Infof("Using AWS assumed role %v", roleARN)
sourceAcct, err := getSourceAccount(roleARN)
if err != nil {
logrus.Errorf("failed to parse source account from role ARN %v: %v", roleARN, err)
return stsClient
}
reqHeaders := map[string]string{
headerSourceAccount: sourceAcct,
}
if sourceARN != "" {
reqHeaders[headerSourceArn] = sourceARN
}
stsClient.Handlers.Sign.PushFront(func(s *request.Request) {
s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders))
})
logrus.Infof("configuring STS client with extra headers, %v", reqHeaders)
return stsClient
}

// getSourceAccount constructs source acct and return them for use
func getSourceAccount(roleARN string) (string, error) {
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
// arn:partition:service:region:account-id:resource-type/resource-id
// IAM format, region is always blank
// arn:aws:iam::account:role/role-name-with-path
if !arn.IsARN(roleARN) {
return "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
}

parsedArn, err := arn.Parse(roleARN)
if err != nil {
return "", err
}

return parsedArn.AccountID, nil
}
41 changes: 41 additions & 0 deletions pkg/ec2provider/ec2provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,44 @@ func prepare100InstanceOutput() []*ec2.Reservation {
return reservations

}

func TestGetSourceAcctAndArn(t *testing.T) {
type args struct {
roleARN string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "corect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876:role/test-cluster",
},
want: "123456789876",
wantErr: false,
},
{
name: "incorect role arn",
args: args{
roleARN: "arn:aws:iam::123456789876",
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getSourceAccount(tt.args.roleARN)
if (err != nil) != tt.wantErr {
t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2

h := &handler{
verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion),
ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst),
ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst),
clusterID: c.ClusterID,
backendMapper: backendMapper,
scrubbedAccounts: c.Config.ScrubbedAWSAccounts,
Expand Down

0 comments on commit 026a356

Please sign in to comment.