Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- `core`: [v0.y.z]
- **Feature:** Add package `runtime`, which implements methods to be used when performing API requests.
- **Feature:** Add method `WithCaptureHTTPResponse` to package `runtime`, which does the same as `config.WithCaptureHTTPResponse`. Method was moved to avoid confusion due to it not being a configuration option, and will be removed in a later release.
- **Feature:** Add configuration option that, for the key flow, enables a goroutine to be spawned that will refresh the access token when it's close to expiring
- **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead.
- **Deprecation:** Mark method `config.WithJWKSEndpoint` and field `config.Configuration.JWKSCustomUrl` as deprecated. Validation using JWKS was removed, for being redundant with token validation done in the APIs. These have no effect.
- **Breaking Change:** Remove method `KeyFlow.Clone`, that was no longer being used.
Expand Down
4 changes: 4 additions & 0 deletions core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v0.10.0 (YYYY-MM-DD)

- **Feature:** Add configuration option that, for the key flow, enables a goroutine to be spawned that will refresh the access token when it's close to expiring

## v0.9.0 (2024-02-19)

- **Deprecation:** Mark method `config.WithCaptureHTTPResponse` as deprecated, to avoid confusion due to it not being a configuration option. Use `runtime.WithCaptureHTTPResponse` instead.
Expand Down
9 changes: 5 additions & 4 deletions core/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,11 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) {
}

keyCfg := clients.KeyFlowConfig{
ServiceAccountKey: serviceAccountKey,
PrivateKey: cfg.PrivateKey,
ClientRetry: cfg.RetryOptions,
TokenUrl: cfg.TokenCustomUrl,
ServiceAccountKey: serviceAccountKey,
PrivateKey: cfg.PrivateKey,
ClientRetry: cfg.RetryOptions,
TokenUrl: cfg.TokenCustomUrl,
BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext,
}

client := &clients.KeyFlow{}
Expand Down
92 changes: 72 additions & 20 deletions core/clients/key_flow.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clients

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
Expand All @@ -10,8 +11,11 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/stackitcloud/stackit-sdk-go/core/oapierror"

"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
Expand All @@ -36,15 +40,18 @@ type KeyFlow struct {
key *ServiceAccountKeyResponse
privateKey *rsa.PrivateKey
privateKeyPEM []byte
token *TokenResponseBody

tokenMutex sync.RWMutex
token *TokenResponseBody
}

// KeyFlowConfig is the flow config
type KeyFlowConfig struct {
ServiceAccountKey *ServiceAccountKeyResponse
PrivateKey string
ClientRetry *RetryConfig
TokenUrl string
ServiceAccountKey *ServiceAccountKeyResponse
PrivateKey string
ClientRetry *RetryConfig
TokenUrl string
BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil
}

// TokenResponseBody is the API response
Expand Down Expand Up @@ -97,13 +104,19 @@ func (c *KeyFlow) GetServiceAccountEmail() string {

// GetToken returns the token field
func (c *KeyFlow) GetToken() TokenResponseBody {
c.tokenMutex.RLock()
defer c.tokenMutex.RUnlock()

if c.token == nil {
return TokenResponseBody{}
}
// Returned struct is passed by value (because it's a struct)
// So no deepy copy needed
return *c.token
}

func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
// No concurrency at this point, so no mutex check needed
c.token = &TokenResponseBody{}
c.config = cfg
c.doer = Do
Expand All @@ -115,7 +128,14 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
if c.config.ClientRetry == nil {
c.config.ClientRetry = NewRetryConfig()
}
return c.validate()
err := c.validate()
if err != nil {
return err
}
if c.config.BackgroundTokenRefreshContext != nil {
go continuousRefreshToken(c)
}
return nil
}

// SetToken can be used to set an access and refresh token manually in the client.
Expand All @@ -132,13 +152,15 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {
return fmt.Errorf("get expiration time from access token: %w", err)
}

c.tokenMutex.Lock()
c.token = &TokenResponseBody{
AccessToken: accessToken,
ExpiresIn: int(exp.Time.Unix()),
Scope: defaultScope,
RefreshToken: refreshToken,
TokenType: defaultTokenType,
}
c.tokenMutex.Unlock()
return nil
}

Expand All @@ -158,17 +180,21 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {

// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
func (c *KeyFlow) GetAccessToken() (string, error) {
accessTokenExpired, err := tokenExpired(c.token.AccessToken)
c.tokenMutex.RLock()
accessToken := c.token.AccessToken
c.tokenMutex.RUnlock()

accessTokenExpired, err := tokenExpired(accessToken)
if err != nil {
return "", fmt.Errorf("failed initial validation: %w", err)
return "", fmt.Errorf("check access token is expired: %w", err)
}
if !accessTokenExpired {
return c.token.AccessToken, nil
return accessToken, nil
}
if err := c.recreateAccessToken(); err != nil {
return "", fmt.Errorf("failed during token recreation: %w", err)
return "", fmt.Errorf("get new access token: %w", err)
}
return c.token.AccessToken, nil
return accessToken, nil
}

// configureHTTPClient configures the HTTP client
Expand All @@ -191,7 +217,7 @@ func (c *KeyFlow) validate() error {
var err error
c.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM([]byte(c.config.PrivateKey))
if err != nil {
return fmt.Errorf("parsing private key from PEM file: %w", err)
return fmt.Errorf("parse private key from PEM file: %w", err)
}

// Encode the private key in PEM format
Expand All @@ -209,7 +235,11 @@ func (c *KeyFlow) validate() error {
// recreateAccessToken is used to create a new access token
// when the existing one isn't valid anymore
func (c *KeyFlow) recreateAccessToken() error {
refreshTokenExpired, err := tokenExpired(c.token.RefreshToken)
c.tokenMutex.RLock()
refreshToken := c.token.RefreshToken
c.tokenMutex.RUnlock()

refreshTokenExpired, err := tokenExpired(refreshToken)
if err != nil {
return err
}
Expand All @@ -233,7 +263,7 @@ func (c *KeyFlow) createAccessToken() (err error) {
defer func() {
tempErr := res.Body.Close()
if tempErr != nil {
err = fmt.Errorf("closing request access token response: %w", tempErr)
err = fmt.Errorf("close request access token response: %w", tempErr)
}
}()
return c.parseTokenResponse(res)
Expand All @@ -242,14 +272,18 @@ func (c *KeyFlow) createAccessToken() (err error) {
// createAccessTokenWithRefreshToken creates an access token using
// an existing pre-validated refresh token
func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) {
res, err := c.requestToken("refresh_token", c.token.RefreshToken)
c.tokenMutex.RLock()
refreshToken := c.token.RefreshToken
c.tokenMutex.RUnlock()

res, err := c.requestToken("refresh_token", refreshToken)
if err != nil {
return err
}
defer func() {
tempErr := res.Body.Close()
if tempErr != nil {
err = fmt.Errorf("closing request access token with refresh token response: %w", tempErr)
err = fmt.Errorf("close request access token with refresh token response: %w", tempErr)
}
}()
return c.parseTokenResponse(res)
Expand Down Expand Up @@ -294,26 +328,44 @@ func (c *KeyFlow) parseTokenResponse(res *http.Response) error {
return fmt.Errorf("received bad response from API")
}
if res.StatusCode != http.StatusOK {
return fmt.Errorf("received: %+v", res)
body, err := io.ReadAll(res.Body)
if err != nil {
// Fail silently, omit body from error
// We're trying to show error details, so it's unnecessary to fail because of this err
body = []byte{}
}
return &oapierror.GenericOpenAPIError{
StatusCode: res.StatusCode,
Body: body,
ErrorMessage: err.Error(),
}
}
body, err := io.ReadAll(res.Body)
if err != nil {
return err
}

c.tokenMutex.Lock()
c.token = &TokenResponseBody{}
return json.Unmarshal(body, c.token)
err = json.Unmarshal(body, c.token)
c.tokenMutex.Unlock()
if err != nil {
return fmt.Errorf("unmarshal token response: %w", err)
}

return nil
}

func tokenExpired(token string) (bool, error) {
// We can safely use ParseUnverified because we are not authenticating the user at this point.
// We're just checking the expiration time
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
if err != nil {
return false, fmt.Errorf("parse access token: %w", err)
return false, fmt.Errorf("parse token: %w", err)
}
expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
if err != nil {
return false, fmt.Errorf("get expiration timestamp from access token: %w", err)
return false, fmt.Errorf("get expiration timestamp: %w", err)
}
expirationTimestamp := expirationTimestampNumeric.Time
now := time.Now()
Expand Down
126 changes: 126 additions & 0 deletions core/clients/key_flow_continuous_refresh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package clients

import (
"errors"
"fmt"
"os"
"time"

"github.com/golang-jwt/jwt/v5"

"github.com/stackitcloud/stackit-sdk-go/core/oapierror"
)

var (
defaultTimeStartBeforeTokenExpiration = 30 * time.Minute
defaultTimeBetweenContextCheck = time.Second
defaultTimeBetweenTries = 5 * time.Minute
)

// Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates.
//
// To terminate this routine, close the context in keyFlow.config.TokenRefreshInBackgroundContext.
func continuousRefreshToken(keyflow *KeyFlow) {
refresher := &continuousTokenRefresher{
keyFlow: keyflow,
timeStartBeforeTokenExpiration: defaultTimeStartBeforeTokenExpiration,
timeBetweenContextCheck: defaultTimeBetweenContextCheck,
timeBetweenTries: defaultTimeBetweenTries,
}
err := refresher.continuousRefreshToken()
fmt.Fprintf(os.Stderr, "Token refreshing terminated: %v", err)
}

type continuousTokenRefresher struct {
keyFlow *KeyFlow
// Token refresh tries start at [Access token expiration timestamp] - [This duration]
timeStartBeforeTokenExpiration time.Duration
timeBetweenContextCheck time.Duration
timeBetweenTries time.Duration
}

// Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Always returns with a non-nil error.
//
// To terminate this routine, close the context in refresher.keyFlow.config.TokenRefreshInBackgroundContext.
func (refresher *continuousTokenRefresher) continuousRefreshToken() error {
expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp()
if err != nil {
return fmt.Errorf("get access token expiration timestamp: %w", err)
}
startRefreshTimestamp := expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration)

for {
err = refresher.waitUntilTimestamp(startRefreshTimestamp)
if err != nil {
return err
}

err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err()
if err != nil {
return fmt.Errorf("check context: %w", err)
}

ok, err := refresher.refreshToken()
if err != nil {
return fmt.Errorf("refresh tokens: %w", err)
}
if !ok {
startRefreshTimestamp = startRefreshTimestamp.Add(refresher.timeBetweenTries)
continue
}

expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp()
if err != nil {
return fmt.Errorf("get access token expiration timestamp: %w", err)
}
startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration)
}
}

func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) {
token := refresher.keyFlow.token.AccessToken

// We can safely use ParseUnverified because we are not doing authentication of any kind
// We're just checking the expiration time
tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{})
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime()
if err != nil {
return nil, fmt.Errorf("get expiration timestamp: %w", err)
}
return &expirationTimestampNumeric.Time, nil
}

func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Time) error {
for time.Now().Before(timestamp) {
err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err()
if err != nil {
return fmt.Errorf("check context: %w", err)
}
time.Sleep(refresher.timeBetweenContextCheck)
}
return nil
}

// Returns:
// - (true, nil) if successful.
// - (false, nil) if not successful but should be retried.
// - (_, err) if not successful and shouldn't be retried.
func (refresher *continuousTokenRefresher) refreshToken() (bool, error) {
err := refresher.keyFlow.createAccessTokenWithRefreshToken()
if err == nil {
return true, nil
}

// Should be retired if this is an API error with status code non-5xx
oapiErr := &oapierror.GenericOpenAPIError{}
if !errors.As(err, &oapiErr) {
return false, err
}
if oapiErr.StatusCode < 500 {
return false, err
}
return false, nil
}
Loading