Skip to content

Commit

Permalink
add identity token fetcher
Browse files Browse the repository at this point in the history
  • Loading branch information
vinay-gopalan committed Jan 22, 2024
1 parent 7a5e901 commit c4c6ec7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
27 changes: 27 additions & 0 deletions awsutil/generate_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type CredentialsConfig struct {
// identity token provider
WebIdentityToken string

// The web identity token fetcher to use with the web identity token provider
WebIdentityTokenFetcher stscreds.TokenFetcher

// The http.Client to use, or nil for the client to use its default
HTTPClient *http.Client

Expand Down Expand Up @@ -134,6 +137,8 @@ func NewCredentialsConfig(opt ...Option) (*CredentialsConfig, error) {
}
c.WebIdentityToken = opts.withWebIdentityToken

c.WebIdentityTokenFetcher = opts.withWebIdentityTokenFetcher

if c.RoleARN == "" {
if c.RoleSessionName != "" {
return nil, fmt.Errorf("role session name specified without role ARN")
Expand Down Expand Up @@ -265,6 +270,28 @@ func (c *CredentialsConfig) GenerateCredentialChain(opt ...Option) (*credentials
}
webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, FetchTokenContents(c.WebIdentityToken))

if opts.withSkipWebIdentityValidity {
// Add the web identity role credential provider without
// generating credentials to check validity first
providers = append(providers, webIdentityProvider)
} else {
// Check if the webIdentityProvider can successfully retrieve
// credentials (via sts:AssumeRole), and warn if there's a problem.
if _, err := webIdentityProvider.Retrieve(); err != nil {
c.log(hclog.Warn, "error assuming role with WebIdentityToken", "roleARN", roleARN, "sessionName", roleSessionName, "err", err)
} else {
// Add the web identity role credential provider
providers = append(providers, webIdentityProvider)
}
}
} else if c.WebIdentityTokenFetcher != nil {
c.log(hclog.Debug, "adding web identity provider with token fetcher", "roleARN", roleARN)
sess, err := session.NewSession()
if err != nil {
return nil, errors.Wrap(err, "error creating a new session to create a WebIdentityRoleProvider with token fetcher")
}
webIdentityProvider := stscreds.NewWebIdentityRoleProviderWithToken(sts.New(sess), roleARN, roleSessionName, c.WebIdentityTokenFetcher)

if opts.withSkipWebIdentityValidity {
// Add the web identity role credential provider without
// generating credentials to check validity first
Expand Down
12 changes: 12 additions & 0 deletions awsutil/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"time"

"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/hashicorp/go-hclog"
)
Expand Down Expand Up @@ -50,6 +51,7 @@ type options struct {
withWebIdentityTokenFile string
withWebIdentityToken string
withSkipWebIdentityValidity bool
withWebIdentityTokenFetcher stscreds.TokenFetcher
withHttpClient *http.Client
withValidityCheckTimeout time.Duration
withIAMAPIFunc IAMAPIFunc
Expand Down Expand Up @@ -124,6 +126,16 @@ func WithWebIdentityToken(with string) Option {
}
}

// WithWebIdentityTokenFetcher allows passing an STS TokenFetcher which
// allows the AWS SDK client automatically to refresh the web identity token
// from any source.
func WithWebIdentityTokenFetcher(with stscreds.TokenFetcher) Option {
return func(o *options) error {
o.withWebIdentityTokenFetcher = with
return nil
}
}

// WithSkipWebIdentityValidity allows controlling whether the validity check is
// skipped for the web identity provider
func WithSkipWebIdentityValidity(with bool) Option {
Expand Down
14 changes: 14 additions & 0 deletions awsutil/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ func Test_GetOpts(t *testing.T) {
testOpts.withWebIdentityToken = "foo"
assert.Equal(t, opts, testOpts)
})
t.Run("WithWebIdentityTokenFetcher", func(t *testing.T) {
f := testFetcher{}
opts, err := getOpts(WithWebIdentityTokenFetcher(f))
require.NoError(t, err)
testOpts := getDefaultOptions()
testOpts.withWebIdentityTokenFetcher = f
assert.Equal(t, opts, testOpts)
})
t.Run("WithSkipWebIdentityValidity", func(t *testing.T) {
opts, err := getOpts(WithSkipWebIdentityValidity(true))
require.NoError(t, err)
Expand All @@ -185,3 +193,9 @@ func Test_GetOpts(t *testing.T) {
assert.Equal(t, opts, testOpts)
})
}

type testFetcher struct{}

func (testFetcher) FetchToken(_ aws.Context) ([]byte, error) {
return nil, nil
}

0 comments on commit c4c6ec7

Please sign in to comment.