Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions core/clients/key_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ func (c *KeyFlow) GetAccessToken() (string, error) {
if err := c.recreateAccessToken(); err != nil {
return "", fmt.Errorf("get new access token: %w", err)
}

c.tokenMutex.RLock()
accessToken = c.token.AccessToken
c.tokenMutex.RUnlock()

return accessToken, nil
}

Expand Down Expand Up @@ -312,7 +317,11 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) {
func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) {
body := url.Values{}
body.Set("grant_type", grant)
body.Set("assertion", assertion)
if grant == "refresh_token" {
body.Set("refresh_token", assertion)
} else {
body.Set("assertion", assertion)
}
payload := strings.NewReader(body.Encode())
req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload)
if err != nil {
Expand All @@ -335,9 +344,8 @@ func (c *KeyFlow) parseTokenResponse(res *http.Response) error {
body = []byte{}
}
return &oapierror.GenericOpenAPIError{
StatusCode: res.StatusCode,
Body: body,
ErrorMessage: err.Error(),
StatusCode: res.StatusCode,
Body: body,
}
}
body, err := io.ReadAll(res.Body)
Expand All @@ -357,6 +365,10 @@ func (c *KeyFlow) parseTokenResponse(res *http.Response) error {
}

func tokenExpired(token string) (bool, error) {
if token == "" {
return true, nil
}

// We can safely use ParseUnverified because we are not authenticating the user at this point.
// We're just checking the expiration time
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
Expand Down
27 changes: 20 additions & 7 deletions core/clients/key_flow_continuous_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,30 @@ type continuousTokenRefresher struct {
//
// To terminate this routine, close the context in refresher.keyFlow.config.BackgroundTokenRefreshContext.
func (refresher *continuousTokenRefresher) continuousRefreshToken() error {
expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp()
if err != nil {
return fmt.Errorf("get access token expiration timestamp: %w", err)
// Compute timestamp where we'll refresh token
// Access token may be empty at this point, we have to check it
var startRefreshTimestamp time.Time

refresher.keyFlow.tokenMutex.RLock()
accessToken := refresher.keyFlow.token.AccessToken
refresher.keyFlow.tokenMutex.RUnlock()
if accessToken == "" {
startRefreshTimestamp = time.Now()
} else {
expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp()
if err != nil {
return fmt.Errorf("get access token expiration timestamp: %w", err)
}
startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration)
}
startRefreshTimestamp := expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration)

for {
err = refresher.waitUntilTimestamp(startRefreshTimestamp)
err := refresher.waitUntilTimestamp(startRefreshTimestamp)
if err != nil {
return err
}

err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err()
err = refresher.keyFlow.config.BackgroundTokenRefreshContext.Err()
if err != nil {
return fmt.Errorf("check context: %w", err)
}
Expand All @@ -78,7 +89,9 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error {
}

func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) {
refresher.keyFlow.tokenMutex.RLock()
token := refresher.keyFlow.token.AccessToken
refresher.keyFlow.tokenMutex.RUnlock()

// We can safely use ParseUnverified because we are not doing authentication of any kind
// We're just checking the expiration time
Expand Down Expand Up @@ -109,7 +122,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim
// - (false, nil) if not successful but should be retried.
// - (_, err) if not successful and shouldn't be retried.
func (refresher *continuousTokenRefresher) refreshToken() (bool, error) {
err := refresher.keyFlow.createAccessTokenWithRefreshToken()
err := refresher.keyFlow.recreateAccessToken()
if err == nil {
return true, nil
}
Expand Down
61 changes: 47 additions & 14 deletions core/clients/key_flow_continuous_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,20 @@ func TestContinuousRefreshToken(t *testing.T) {

for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("failed to create access token: %v", err)
}

refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("failed to create refresh token: %v", err)
}

numberDoCalls := 0
mockDo := func(client *http.Client, req *http.Request, cfg *RetryConfig) (resp *http.Response, err error) {
numberDoCalls++
Expand All @@ -93,15 +107,16 @@ func TestContinuousRefreshToken(t *testing.T) {
return nil, tt.doError
}

accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("Do call: failed to create access token: %v", err)
}

responseBodyStruct := TokenResponseBody{
AccessToken: accessToken,
AccessToken: newAccessToken,
RefreshToken: refreshToken,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
Expand All @@ -114,13 +129,6 @@ func TestContinuousRefreshToken(t *testing.T) {
return response, nil
}

accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("failed to create access token: %v", err)
}

ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn)
defer cancel()
Expand All @@ -132,7 +140,8 @@ func TestContinuousRefreshToken(t *testing.T) {
},
doer: mockDo,
token: &TokenResponseBody{
AccessToken: accessToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
}

Expand All @@ -155,7 +164,7 @@ func TestContinuousRefreshToken(t *testing.T) {
}

// Tests if
// - continuousRefreshToken() changes the token
// - continuousRefreshToken() updates access token using the refresh token
// - The access token can be accessed while continuousRefreshToken() is trying to update it
func TestContinuousRefreshTokenConcurrency(t *testing.T) {
// The times here are in the order of miliseconds (so they run faster)
Expand Down Expand Up @@ -203,6 +212,14 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
t.Fatalf("created tokens are equal")
}

// The refresh token used to update the access token
refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("failed to create refresh token: %v", err)
}

ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel() // This cancels the refresher goroutine
Expand Down Expand Up @@ -233,13 +250,28 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase)
}

// Check required fields are passed
err = req.ParseForm()
if err != nil {
t.Fatalf("Do call: failed to parse body form: %v", err)
}
reqGrantType := req.Form.Get("grant_type")
if reqGrantType != "refresh_token" {
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType)
}
reqRefreshToken := req.Form.Get("refresh_token")
if reqRefreshToken != refreshToken {
t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set")
}

// Return response with accessTokenSecond
responseBodyStruct := TokenResponseBody{
AccessToken: accessTokenSecond,
AccessToken: accessTokenSecond,
RefreshToken: refreshToken,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
t.Fatalf("Do call: failed to marshal access token response: %v", err)
t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err)
}
response := &http.Response{
StatusCode: http.StatusOK,
Expand Down Expand Up @@ -303,7 +335,8 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
},
doer: mockDo,
token: &TokenResponseBody{
AccessToken: accessTokenFirst,
AccessToken: accessTokenFirst,
RefreshToken: refreshToken,
},
}

Expand Down