Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update sso credential provider to support token provider #4875

Merged
merged 8 commits into from
Jun 15, 2023
58 changes: 44 additions & 14 deletions aws/credentials/ssocreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/auth/bearer"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
Expand Down Expand Up @@ -54,6 +55,22 @@ type Provider struct {

// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
StartURL string

// The filepath the cached token will be retrieved from. If unset Provider will
// use the startURL to determine the filepath at.
//
// ~/.aws/sso/cache/<sha1-hex-encoded-startURL>.json
//
// If custom cached token filepath is used, the Provider's startUrl
// parameter will be ignored.
CachedTokenFilepath string

// Used by the SSOCredentialProvider to judge if TokenProvider is configured
HasTokenProvider bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? bearer.TokenProvider is an interface, you can check if it's nil as you were before.


// Used by the SSOCredentialProvider if a token configuration
// profile is used in the shared config
TokenProvider bearer.TokenProvider
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: id like to discuss this as a team. my understanding is that this can be optional. and my understanding of Go is that the best way to signal optionality is via pointer types? but im wondering if thats only the case for primitive types and not object types?

Copy link
Contributor Author

@wty-Bryant wty-Bryant Jun 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bear.TokenProvider is an interface so it can be nil which I think can mark it is optional? Interface doesn't support pointer receiver afaik

}

// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
Expand Down Expand Up @@ -88,13 +105,31 @@ func (p *Provider) Retrieve() (credentials.Value, error) {
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
tokenFile, err := loadTokenFile(p.StartURL)
if err != nil {
return credentials.Value{}, err
var accessToken *string
if p.HasTokenProvider {
token, err := p.TokenProvider.RetrieveBearerToken(ctx)
if err != nil {
return credentials.Value{}, err
}
accessToken = &token.Value
} else {
if p.CachedTokenFilepath == "" {
cachedTokenFilePath, err := getCachedFilePath(p.StartURL)
if err != nil {
return credentials.Value{}, err
}
p.CachedTokenFilepath = cachedTokenFilePath
}

tokenFile, err := loadTokenFile(p.CachedTokenFilepath)
if err != nil {
return credentials.Value{}, err
}
accessToken = &tokenFile.AccessToken
}

output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
AccessToken: &tokenFile.AccessToken,
AccessToken: accessToken,
AccountId: &p.AccountID,
RoleName: &p.RoleName,
})
Expand All @@ -113,13 +148,13 @@ func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Val
}, nil
}

func getCacheFileName(url string) (string, error) {
func getCachedFilePath(startUrl string) (string, error) {
hash := sha1.New()
_, err := hash.Write([]byte(url))
_, err := hash.Write([]byte(startUrl))
if err != nil {
return "", err
}
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
return filepath.Join(defaultCacheLocation(), strings.ToLower(hex.EncodeToString(hash.Sum(nil)))+".json"), nil
}

type token struct {
Expand All @@ -133,13 +168,8 @@ func (t token) Expired() bool {
return nowTime().Round(0).After(time.Time(t.ExpiresAt))
}

func loadTokenFile(startURL string) (t token, err error) {
key, err := getCacheFileName(startURL)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}

fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
func loadTokenFile(cachedTokenPath string) (t token, err error) {
fileBytes, err := ioutil.ReadFile(cachedTokenPath)
if err != nil {
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
}
Expand Down
116 changes: 107 additions & 9 deletions aws/credentials/ssocreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ package ssocreds

import (
"fmt"
"path/filepath"
"reflect"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/auth/bearer"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/sso"
Expand All @@ -32,6 +34,18 @@ type mockClient struct {
Response func(mockClient) (*sso.GetRoleCredentialsOutput, error)
}

type mockTokenProvider struct {
Response func() (bearer.Token, error)
}

func (p mockTokenProvider) RetrieveBearerToken(ctx aws.Context) (bearer.Token, error) {
if p.Response == nil {
return bearer.Token{}, nil
}

return p.Response()
}

func (m mockClient) GetRoleCredentialsWithContext(ctx aws.Context, params *sso.GetRoleCredentialsInput, _ ...request.Option) (*sso.GetRoleCredentialsOutput, error) {
m.t.Helper()

Expand Down Expand Up @@ -88,11 +102,14 @@ func TestProvider(t *testing.T) {
defer restoreTime()

cases := map[string]struct {
Client mockClient
AccountID string
Region string
RoleName string
StartURL string
Client mockClient
AccountID string
Region string
RoleName string
StartURL string
CachedTokenFilePath string
HasTokenProvider bool
TokenProvider mockTokenProvider

ExpectedErr bool
ExpectedCredentials credentials.Value
Expand Down Expand Up @@ -131,6 +148,84 @@ func TestProvider(t *testing.T) {
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"custom cached token file": {
Client: mockClient{
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedAccessToken: "ZhbHVldGhpcyBpcyBub3QgYSByZWFsIH",
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
return &sso.GetRoleCredentialsOutput{
RoleCredentials: &sso.RoleCredentials{
AccessKeyId: aws.String("AccessKey"),
SecretAccessKey: aws.String("SecretKey"),
SessionToken: aws.String("SessionToken"),
Expiration: aws.Int64(1611177743123),
},
}, nil
},
},
CachedTokenFilePath: filepath.Join("testdata", "custom_cached_token.json"),
AccountID: "012345678901",
Region: "us-west-2",
RoleName: "TestRole",
StartURL: "ignored value",
ExpectedCredentials: credentials.Value{
AccessKeyID: "AccessKey",
SecretAccessKey: "SecretKey",
SessionToken: "SessionToken",
ProviderName: ProviderName,
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"access token retrieved by token provider": {
Client: mockClient{
ExpectedAccountID: "012345678901",
ExpectedRoleName: "TestRole",
ExpectedAccessToken: "WFsIHZhbHVldGhpcyBpcyBub3QgYSByZ",
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) {
return &sso.GetRoleCredentialsOutput{
RoleCredentials: &sso.RoleCredentials{
AccessKeyId: aws.String("AccessKey"),
SecretAccessKey: aws.String("SecretKey"),
SessionToken: aws.String("SessionToken"),
Expiration: aws.Int64(1611177743123),
},
}, nil
},
},
TokenProvider: mockTokenProvider{
Response: func() (bearer.Token, error) {
return bearer.Token{
Value: "WFsIHZhbHVldGhpcyBpcyBub3QgYSByZ",
}, nil
},
},
HasTokenProvider: true,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wasnt this field removed in the source code? why is it still in the test data?

AccountID: "012345678901",
Region: "us-west-2",
RoleName: "TestRole",
StartURL: "ignored value",
ExpectedCredentials: credentials.Value{
AccessKeyID: "AccessKey",
SecretAccessKey: "SecretKey",
SessionToken: "SessionToken",
ProviderName: ProviderName,
},
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC),
},
"token provider return error": {
TokenProvider: mockTokenProvider{
Response: func() (bearer.Token, error) {
return bearer.Token{}, fmt.Errorf("mock token provider return error")
},
},
HasTokenProvider: true,
AccountID: "012345678901",
Region: "us-west-2",
RoleName: "TestRole",
StartURL: "ignored value",
ExpectedErr: true,
},
"expired access token": {
StartURL: "https://expired",
ExpectedErr: true,
Expand Down Expand Up @@ -158,10 +253,13 @@ func TestProvider(t *testing.T) {
tt.Client.t = t

provider := &Provider{
Client: tt.Client,
AccountID: tt.AccountID,
RoleName: tt.RoleName,
StartURL: tt.StartURL,
Client: tt.Client,
AccountID: tt.AccountID,
RoleName: tt.RoleName,
StartURL: tt.StartURL,
CachedTokenFilepath: tt.CachedTokenFilePath,
HasTokenProvider: tt.HasTokenProvider,
TokenProvider: tt.TokenProvider,
}

provider.Expiry.CurrentTime = nowTime
Expand Down
4 changes: 4 additions & 0 deletions aws/credentials/ssocreds/testdata/custom_cached_token.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"accessToken": "ZhbHVldGhpcyBpcyBub3QgYSByZWFsIH",
"expiresAt": "2021-01-19T23:00:00Z"
}
24 changes: 22 additions & 2 deletions aws/session/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/aws/aws-sdk-go/service/ssooidc"
"github.com/aws/aws-sdk-go/service/sts"
)

Expand All @@ -33,7 +34,7 @@ func resolveCredentials(cfg *aws.Config,

switch {
case len(sessOpts.Profile) != 0:
// User explicitly provided an Profile in the session's configuration
// User explicitly provided a Profile in the session's configuration
// so load that profile from shared config first.
// Github(aws/aws-sdk-go#2727)
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts)
Expand Down Expand Up @@ -173,8 +174,26 @@ func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers req
return nil, err
}

var optFns []func(provider *ssocreds.Provider)
cfgCopy := cfg.Copy()
cfgCopy.Region = &sharedCfg.SSORegion

if sharedCfg.SSOSession != nil {
cfgCopy.Region = &sharedCfg.SSOSession.SSORegion
cachedPath, err := ssocreds.StandardCachedTokenFilepath(sharedCfg.SSOSession.Name)
if err != nil {
return nil, err
}
mySession := Must(NewSession())
oidcClient := ssooidc.New(mySession, cfgCopy)
tokenProvider := ssocreds.NewSSOTokenProvider(oidcClient, cachedPath)
optFns = append(optFns, func(p *ssocreds.Provider) {
p.HasTokenProvider = true
p.TokenProvider = *tokenProvider
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we dereferencing here? This doesn't look right. The methods should probably be pointer receivers (looks like they are value receivers currently).

p.CachedTokenFilepath = cachedPath
})
} else {
cfgCopy.Region = &sharedCfg.SSORegion
}

return ssocreds.NewCredentials(
&Session{
Expand All @@ -184,6 +203,7 @@ func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers req
sharedCfg.SSOAccountID,
sharedCfg.SSORoleName,
sharedCfg.SSOStartURL,
optFns...,
), nil
}

Expand Down