diff --git a/sdk/auth/access_token_source.go b/sdk/auth/access_token_source.go index dd72dc6a03..b685d0bdcc 100644 --- a/sdk/auth/access_token_source.go +++ b/sdk/auth/access_token_source.go @@ -11,5 +11,4 @@ type AccessTokenSource interface { DecryptWithDPoPKey(data []byte) ([]byte, error) MakeToken(func(jwk.Key) ([]byte, error)) ([]byte, error) DPoPPublicKeyPEM() string - RefreshAccessToken() error } diff --git a/sdk/idp_access_token_source.go b/sdk/idp_access_token_source.go index 30c398deba..3d51fdb07d 100644 --- a/sdk/idp_access_token_source.go +++ b/sdk/idp_access_token_source.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "log/slog" "net/url" "strings" "sync" @@ -15,7 +16,6 @@ import ( "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/sdk/internal/crypto" "github.com/opentdf/platform/sdk/internal/oauth" - "golang.org/x/oauth2" ) const ( @@ -81,7 +81,7 @@ to a DPoP key type IDPAccessTokenSource struct { credentials oauth.ClientCredentials idpTokenEndpoint url.URL - token *oauth2.Token + token *oauth.Token scopes []string dpopKey jwk.Key asymDecryption crypto.AsymDecryption @@ -117,11 +117,16 @@ func NewIDPAccessTokenSource( // use a pointer receiver so that the token state is shared func (t *IDPAccessTokenSource) AccessToken() (auth.AccessToken, error) { - if t.token == nil { - err := t.RefreshAccessToken() + t.tokenMutex.Lock() + defer t.tokenMutex.Unlock() + + if t.token == nil || t.token.Expired() { + slog.Debug("getting new access token") + tok, err := oauth.GetAccessToken(t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey) if err != nil { - return auth.AccessToken(""), err + return "", fmt.Errorf("error getting access token: %w", err) } + t.token = tok } return auth.AccessToken(t.token.AccessToken), nil @@ -131,19 +136,6 @@ func (t *IDPAccessTokenSource) DecryptWithDPoPKey(data []byte) ([]byte, error) { return t.asymDecryption.Decrypt(data) } -func (t *IDPAccessTokenSource) RefreshAccessToken() error { - t.tokenMutex.Lock() - defer t.tokenMutex.Unlock() - - tok, err := oauth.GetAccessToken(t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey) - if err != nil { - return fmt.Errorf("error getting access token: %w", err) - } - t.token = tok - - return nil -} - func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return tokenMaker(t.dpopKey) } diff --git a/sdk/internal/oauth/oauth.go b/sdk/internal/oauth/oauth.go index a19e111432..4c0f2e1e9f 100644 --- a/sdk/internal/oauth/oauth.go +++ b/sdk/internal/oauth/oauth.go @@ -15,7 +15,10 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" - "golang.org/x/oauth2" +) + +const ( + tokenExpirationBuffer = 10 * time.Second ) type ClientCredentials struct { @@ -23,6 +26,24 @@ type ClientCredentials struct { ClientId string } +type Token struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` + received time.Time +} + +func (t Token) Expired() bool { + if t.ExpiresIn == 0 { + return false + } + + expirationTime := t.received.Add(time.Second * time.Duration(t.ExpiresIn)) + + return time.Now().After(expirationTime.Add(-tokenExpirationBuffer)) +} + func getRequest(tokenEndpoint, dpopNonce string, scopes []string, clientCredentials ClientCredentials, privateJWK *jwk.Key) (*http.Request, error) { req, err := http.NewRequest("POST", tokenEndpoint, nil) if err != nil { @@ -96,7 +117,7 @@ func getSignedToken(clientID, tokenEndpoint string, key jwk.Key) ([]byte, error) // this misses the flow where the Authorization server can tell us the next nonce to use. // missing this flow costs us a bit in efficiency (a round trip per access token) but this is // still correct because -func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*oauth2.Token, error) { +func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*Token, error) { req, err := getRequest(tokenEndpoint, "", scopes, clientCredentials, &dpopPrivateKey) if err != nil { return nil, err @@ -128,7 +149,7 @@ func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials Cli return processResponse(resp) } -func processResponse(resp *http.Response) (*oauth2.Token, error) { +func processResponse(resp *http.Response) (*Token, error) { respBytes, err := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -139,11 +160,13 @@ func processResponse(resp *http.Response) (*oauth2.Token, error) { return nil, fmt.Errorf("error reading bytes from response: %w", err) } - var token *oauth2.Token + var token *Token if err := json.Unmarshal(respBytes, &token); err != nil { return nil, fmt.Errorf("error unmarshaling token from response: %w", err) } + token.received = time.Now() + return token, nil } diff --git a/sdk/internal/oauth/oauth_test.go b/sdk/internal/oauth/oauth_test.go index 614aa0615a..a8f10cc931 100644 --- a/sdk/internal/oauth/oauth_test.go +++ b/sdk/internal/oauth/oauth_test.go @@ -1,4 +1,4 @@ -package oauth_test +package oauth import ( "context" @@ -22,7 +22,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/opentdf/platform/sdk/internal/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" tc "github.com/testcontainers/testcontainers-go" @@ -49,12 +48,12 @@ func TestGettingAccessTokenFromKeycloak(t *testing.T) { require.NoError(t, dpopJWK.Set("use", "sig")) require.NoError(t, dpopJWK.Set("alg", jwa.RS256.String())) - clientCredentials := oauth.ClientCredentials{ + clientCredentials := ClientCredentials{ ClientId: "testclient", ClientAuth: "abcd1234", } - tok, err := oauth.GetAccessToken( + tok, err := GetAccessToken( idpEndpoint, []string{"testscope"}, clientCredentials, @@ -83,6 +82,15 @@ func TestGettingAccessTokenFromKeycloak(t *testing.T) { } else { t.Fatal("no cnf claim in token") } + + if tok.ExpiresIn < 0 { + t.Fatalf("invalid expiration is before current time: %v", tok) + } + + if tok.Expired() { + t.Fatalf("got a token that is currently expired: %v", tok) + } + } func TestClientSecretNoNonce(t *testing.T) { @@ -119,11 +127,11 @@ func TestClientSecretNoNonce(t *testing.T) { })) defer server.Close() - clientCredentials := oauth.ClientCredentials{ + clientCredentials := ClientCredentials{ ClientId: "theclient", ClientAuth: "thesecret", } - _, err = oauth.GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK) + _, err = GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK) require.NoError(t, err, "didn't get a token back from the IdP") } @@ -188,16 +196,44 @@ func TestClientSecretWithNonce(t *testing.T) { })) defer server.Close() - clientCredentials := oauth.ClientCredentials{ + clientCredentials := ClientCredentials{ ClientId: "theclient", ClientAuth: "thesecret", } - _, err = oauth.GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK) + _, err = GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK) if err != nil { t.Errorf("didn't get a token back from the IdP: %v", err) } } +func TestTokenExpiration_RespectsLeeway(t *testing.T) { + expiredToken := Token{ + received: time.Now().Add(-tokenExpirationBuffer - 10*time.Second), + ExpiresIn: 5, + } + if !expiredToken.Expired() { + t.Fatalf("token should be expired") + } + + goodToken := Token{ + received: time.Now(), + ExpiresIn: 2 * int64(tokenExpirationBuffer/time.Second), + } + + if goodToken.Expired() { + t.Fatalf("token should not be expired") + } + + justOverBorderToken := Token{ + received: time.Now(), + ExpiresIn: int64(tokenExpirationBuffer/time.Second) - 1, + } + + if !justOverBorderToken.Expired() { + t.Fatalf("token should not be expired") + } +} + func TestSignedJWTWithNonce(t *testing.T) { // Generate RSA Key to use for DPoP dpopKey, err := rsa.GenerateKey(rand.Reader, 4096) @@ -273,14 +309,14 @@ func TestSignedJWTWithNonce(t *testing.T) { })) defer server.Close() - clientCredentials := oauth.ClientCredentials{ + clientCredentials := ClientCredentials{ ClientId: "theclient", ClientAuth: clientAuthJWK, } url = server.URL + "/token" - _, err = oauth.GetAccessToken(url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK) + _, err = GetAccessToken(url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK) if err != nil { t.Errorf("didn't get a token back from the IdP: %v", err) } diff --git a/sdk/kas_client.go b/sdk/kas_client.go index 86ca168ae7..a87eb01e63 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -12,8 +12,6 @@ import ( kas "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/sdk/auth" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) const ( @@ -66,22 +64,6 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas. func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) { response, err := k.makeRewrapRequest(keyAccess, policy) - if err != nil { - switch status.Code(err) { //nolint:exhaustive // we can only handle authentication - case codes.Unauthenticated: - err = k.accessTokenSource.RefreshAccessToken() - if err != nil { - return nil, fmt.Errorf("error refreshing access token: %w", err) - } - response, err = k.makeRewrapRequest(keyAccess, policy) - if err != nil { - return nil, fmt.Errorf("Error making rewrap request: %w", err) - } - default: - return nil, fmt.Errorf("Error making rewrap request: %w", err) - } - } - key, err := k.accessTokenSource.DecryptWithDPoPKey(response.GetEntityWrappedKey()) if err != nil { return nil, fmt.Errorf("error decrypting payload from KAS: %w", err)