-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update sso credential provider to support token provider #4875
Changes from 5 commits
0a6e3a6
894c8cf
ae875c1
175106b
a25a014
9d68ffd
9d605f3
2fb8973
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ import ( | |
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/auth/bearer" | ||
"github.com/aws/aws-sdk-go/aws/awserr" | ||
"github.com/aws/aws-sdk-go/aws/client" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
|
@@ -54,6 +55,22 @@ type Provider struct { | |
|
||
// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal. | ||
StartURL string | ||
|
||
// The filepath the cached token will be retrieved from. If unset Provider will | ||
// use the startURL to determine the filepath at. | ||
// | ||
// ~/.aws/sso/cache/<sha1-hex-encoded-startURL>.json | ||
// | ||
// If custom cached token filepath is used, the Provider's startUrl | ||
// parameter will be ignored. | ||
CachedTokenFilepath string | ||
|
||
// Used by the SSOCredentialProvider to judge if TokenProvider is configured | ||
HasTokenProvider bool | ||
|
||
// Used by the SSOCredentialProvider if a token configuration | ||
// profile is used in the shared config | ||
TokenProvider bearer.TokenProvider | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: id like to discuss this as a team. my understanding is that this can be optional. and my understanding of Go is that the best way to signal optionality is via pointer types? but im wondering if thats only the case for primitive types and not object types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
} | ||
|
||
// NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured | ||
|
@@ -88,13 +105,31 @@ func (p *Provider) Retrieve() (credentials.Value, error) { | |
// RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal | ||
// by exchanging the accessToken present in ~/.aws/sso/cache. | ||
func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { | ||
tokenFile, err := loadTokenFile(p.StartURL) | ||
if err != nil { | ||
return credentials.Value{}, err | ||
var accessToken *string | ||
if p.HasTokenProvider { | ||
token, err := p.TokenProvider.RetrieveBearerToken(ctx) | ||
if err != nil { | ||
return credentials.Value{}, err | ||
} | ||
accessToken = &token.Value | ||
} else { | ||
if p.CachedTokenFilepath == "" { | ||
cachedTokenFilePath, err := getCachedFilePath(p.StartURL) | ||
if err != nil { | ||
return credentials.Value{}, err | ||
} | ||
p.CachedTokenFilepath = cachedTokenFilePath | ||
} | ||
|
||
tokenFile, err := loadTokenFile(p.CachedTokenFilepath) | ||
if err != nil { | ||
return credentials.Value{}, err | ||
} | ||
accessToken = &tokenFile.AccessToken | ||
} | ||
|
||
output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{ | ||
AccessToken: &tokenFile.AccessToken, | ||
AccessToken: accessToken, | ||
AccountId: &p.AccountID, | ||
RoleName: &p.RoleName, | ||
}) | ||
|
@@ -113,13 +148,13 @@ func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Val | |
}, nil | ||
} | ||
|
||
func getCacheFileName(url string) (string, error) { | ||
func getCachedFilePath(startUrl string) (string, error) { | ||
hash := sha1.New() | ||
_, err := hash.Write([]byte(url)) | ||
_, err := hash.Write([]byte(startUrl)) | ||
if err != nil { | ||
return "", err | ||
} | ||
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil | ||
return filepath.Join(defaultCacheLocation(), strings.ToLower(hex.EncodeToString(hash.Sum(nil)))+".json"), nil | ||
} | ||
|
||
type token struct { | ||
|
@@ -133,13 +168,8 @@ func (t token) Expired() bool { | |
return nowTime().Round(0).After(time.Time(t.ExpiresAt)) | ||
} | ||
|
||
func loadTokenFile(startURL string) (t token, err error) { | ||
key, err := getCacheFileName(startURL) | ||
if err != nil { | ||
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) | ||
} | ||
|
||
fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key)) | ||
func loadTokenFile(cachedTokenPath string) (t token, err error) { | ||
fileBytes, err := ioutil.ReadFile(cachedTokenPath) | ||
if err != nil { | ||
return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err) | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,13 @@ package ssocreds | |
|
||
import ( | ||
"fmt" | ||
"path/filepath" | ||
"reflect" | ||
"testing" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/aws/aws-sdk-go/aws/auth/bearer" | ||
"github.com/aws/aws-sdk-go/aws/credentials" | ||
"github.com/aws/aws-sdk-go/aws/request" | ||
"github.com/aws/aws-sdk-go/service/sso" | ||
|
@@ -32,6 +34,18 @@ type mockClient struct { | |
Response func(mockClient) (*sso.GetRoleCredentialsOutput, error) | ||
} | ||
|
||
type mockTokenProvider struct { | ||
Response func() (bearer.Token, error) | ||
} | ||
|
||
func (p mockTokenProvider) RetrieveBearerToken(ctx aws.Context) (bearer.Token, error) { | ||
if p.Response == nil { | ||
return bearer.Token{}, nil | ||
} | ||
|
||
return p.Response() | ||
} | ||
|
||
func (m mockClient) GetRoleCredentialsWithContext(ctx aws.Context, params *sso.GetRoleCredentialsInput, _ ...request.Option) (*sso.GetRoleCredentialsOutput, error) { | ||
m.t.Helper() | ||
|
||
|
@@ -88,11 +102,14 @@ func TestProvider(t *testing.T) { | |
defer restoreTime() | ||
|
||
cases := map[string]struct { | ||
Client mockClient | ||
AccountID string | ||
Region string | ||
RoleName string | ||
StartURL string | ||
Client mockClient | ||
AccountID string | ||
Region string | ||
RoleName string | ||
StartURL string | ||
CachedTokenFilePath string | ||
HasTokenProvider bool | ||
TokenProvider mockTokenProvider | ||
|
||
ExpectedErr bool | ||
ExpectedCredentials credentials.Value | ||
|
@@ -131,6 +148,84 @@ func TestProvider(t *testing.T) { | |
}, | ||
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC), | ||
}, | ||
"custom cached token file": { | ||
Client: mockClient{ | ||
ExpectedAccountID: "012345678901", | ||
ExpectedRoleName: "TestRole", | ||
ExpectedAccessToken: "ZhbHVldGhpcyBpcyBub3QgYSByZWFsIH", | ||
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) { | ||
return &sso.GetRoleCredentialsOutput{ | ||
RoleCredentials: &sso.RoleCredentials{ | ||
AccessKeyId: aws.String("AccessKey"), | ||
SecretAccessKey: aws.String("SecretKey"), | ||
SessionToken: aws.String("SessionToken"), | ||
Expiration: aws.Int64(1611177743123), | ||
}, | ||
}, nil | ||
}, | ||
}, | ||
CachedTokenFilePath: filepath.Join("testdata", "custom_cached_token.json"), | ||
AccountID: "012345678901", | ||
Region: "us-west-2", | ||
RoleName: "TestRole", | ||
StartURL: "ignored value", | ||
ExpectedCredentials: credentials.Value{ | ||
AccessKeyID: "AccessKey", | ||
SecretAccessKey: "SecretKey", | ||
SessionToken: "SessionToken", | ||
ProviderName: ProviderName, | ||
}, | ||
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC), | ||
}, | ||
"access token retrieved by token provider": { | ||
Client: mockClient{ | ||
ExpectedAccountID: "012345678901", | ||
ExpectedRoleName: "TestRole", | ||
ExpectedAccessToken: "WFsIHZhbHVldGhpcyBpcyBub3QgYSByZ", | ||
Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) { | ||
return &sso.GetRoleCredentialsOutput{ | ||
RoleCredentials: &sso.RoleCredentials{ | ||
AccessKeyId: aws.String("AccessKey"), | ||
SecretAccessKey: aws.String("SecretKey"), | ||
SessionToken: aws.String("SessionToken"), | ||
Expiration: aws.Int64(1611177743123), | ||
}, | ||
}, nil | ||
}, | ||
}, | ||
TokenProvider: mockTokenProvider{ | ||
Response: func() (bearer.Token, error) { | ||
return bearer.Token{ | ||
Value: "WFsIHZhbHVldGhpcyBpcyBub3QgYSByZ", | ||
}, nil | ||
}, | ||
}, | ||
HasTokenProvider: true, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wasnt this field removed in the source code? why is it still in the test data? |
||
AccountID: "012345678901", | ||
Region: "us-west-2", | ||
RoleName: "TestRole", | ||
StartURL: "ignored value", | ||
ExpectedCredentials: credentials.Value{ | ||
AccessKeyID: "AccessKey", | ||
SecretAccessKey: "SecretKey", | ||
SessionToken: "SessionToken", | ||
ProviderName: ProviderName, | ||
}, | ||
ExpectedExpire: time.Date(2021, 01, 20, 21, 22, 23, 0.123e9, time.UTC), | ||
}, | ||
"token provider return error": { | ||
TokenProvider: mockTokenProvider{ | ||
Response: func() (bearer.Token, error) { | ||
return bearer.Token{}, fmt.Errorf("mock token provider return error") | ||
}, | ||
}, | ||
HasTokenProvider: true, | ||
AccountID: "012345678901", | ||
Region: "us-west-2", | ||
RoleName: "TestRole", | ||
StartURL: "ignored value", | ||
ExpectedErr: true, | ||
}, | ||
"expired access token": { | ||
StartURL: "https://expired", | ||
ExpectedErr: true, | ||
|
@@ -158,10 +253,13 @@ func TestProvider(t *testing.T) { | |
tt.Client.t = t | ||
|
||
provider := &Provider{ | ||
Client: tt.Client, | ||
AccountID: tt.AccountID, | ||
RoleName: tt.RoleName, | ||
StartURL: tt.StartURL, | ||
Client: tt.Client, | ||
AccountID: tt.AccountID, | ||
RoleName: tt.RoleName, | ||
StartURL: tt.StartURL, | ||
CachedTokenFilepath: tt.CachedTokenFilePath, | ||
HasTokenProvider: tt.HasTokenProvider, | ||
TokenProvider: tt.TokenProvider, | ||
} | ||
|
||
provider.Expiry.CurrentTime = nowTime | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"accessToken": "ZhbHVldGhpcyBpcyBub3QgYSByZWFsIH", | ||
"expiresAt": "2021-01-19T23:00:00Z" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ import ( | |
"github.com/aws/aws-sdk-go/aws/defaults" | ||
"github.com/aws/aws-sdk-go/aws/request" | ||
"github.com/aws/aws-sdk-go/internal/shareddefaults" | ||
"github.com/aws/aws-sdk-go/service/ssooidc" | ||
"github.com/aws/aws-sdk-go/service/sts" | ||
) | ||
|
||
|
@@ -33,7 +34,7 @@ func resolveCredentials(cfg *aws.Config, | |
|
||
switch { | ||
case len(sessOpts.Profile) != 0: | ||
// User explicitly provided an Profile in the session's configuration | ||
// User explicitly provided a Profile in the session's configuration | ||
// so load that profile from shared config first. | ||
// Github(aws/aws-sdk-go#2727) | ||
return resolveCredsFromProfile(cfg, envCfg, sharedCfg, handlers, sessOpts) | ||
|
@@ -173,8 +174,26 @@ func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers req | |
return nil, err | ||
} | ||
|
||
var optFns []func(provider *ssocreds.Provider) | ||
cfgCopy := cfg.Copy() | ||
cfgCopy.Region = &sharedCfg.SSORegion | ||
|
||
if sharedCfg.SSOSession != nil { | ||
cfgCopy.Region = &sharedCfg.SSOSession.SSORegion | ||
cachedPath, err := ssocreds.StandardCachedTokenFilepath(sharedCfg.SSOSession.Name) | ||
if err != nil { | ||
return nil, err | ||
} | ||
mySession := Must(NewSession()) | ||
oidcClient := ssooidc.New(mySession, cfgCopy) | ||
tokenProvider := ssocreds.NewSSOTokenProvider(oidcClient, cachedPath) | ||
optFns = append(optFns, func(p *ssocreds.Provider) { | ||
p.HasTokenProvider = true | ||
p.TokenProvider = *tokenProvider | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we dereferencing here? This doesn't look right. The methods should probably be pointer receivers (looks like they are value receivers currently). |
||
p.CachedTokenFilepath = cachedPath | ||
}) | ||
} else { | ||
cfgCopy.Region = &sharedCfg.SSORegion | ||
} | ||
|
||
return ssocreds.NewCredentials( | ||
&Session{ | ||
|
@@ -184,6 +203,7 @@ func resolveSSOCredentials(cfg *aws.Config, sharedCfg sharedConfig, handlers req | |
sharedCfg.SSOAccountID, | ||
sharedCfg.SSORoleName, | ||
sharedCfg.SSOStartURL, | ||
optFns..., | ||
), nil | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this?
bearer.TokenProvider
is an interface, you can check if it'snil
as you were before.