Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sdk/auth/access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ type AccessTokenSource interface {
DecryptWithDPoPKey(data []byte) ([]byte, error)
MakeToken(func(jwk.Key) ([]byte, error)) ([]byte, error)
DPoPPublicKeyPEM() string
RefreshAccessToken() error
}
28 changes: 10 additions & 18 deletions sdk/idp_access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"log/slog"
"net/url"
"strings"
"sync"
Expand All @@ -15,7 +16,6 @@ import (
"github.com/opentdf/platform/sdk/auth"
"github.com/opentdf/platform/sdk/internal/crypto"
"github.com/opentdf/platform/sdk/internal/oauth"
"golang.org/x/oauth2"
)

const (
Expand Down Expand Up @@ -81,7 +81,7 @@ to a DPoP key
type IDPAccessTokenSource struct {
credentials oauth.ClientCredentials
idpTokenEndpoint url.URL
token *oauth2.Token
token *oauth.Token
scopes []string
dpopKey jwk.Key
asymDecryption crypto.AsymDecryption
Expand Down Expand Up @@ -117,11 +117,16 @@ func NewIDPAccessTokenSource(

// use a pointer receiver so that the token state is shared
func (t *IDPAccessTokenSource) AccessToken() (auth.AccessToken, error) {
if t.token == nil {
err := t.RefreshAccessToken()
t.tokenMutex.Lock()
defer t.tokenMutex.Unlock()

if t.token == nil || t.token.Expired() {
slog.Debug("getting new access token")
tok, err := oauth.GetAccessToken(t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey)
if err != nil {
return auth.AccessToken(""), err
return "", fmt.Errorf("error getting access token: %w", err)
}
t.token = tok
}

return auth.AccessToken(t.token.AccessToken), nil
Expand All @@ -131,19 +136,6 @@ func (t *IDPAccessTokenSource) DecryptWithDPoPKey(data []byte) ([]byte, error) {
return t.asymDecryption.Decrypt(data)
}

func (t *IDPAccessTokenSource) RefreshAccessToken() error {
t.tokenMutex.Lock()
defer t.tokenMutex.Unlock()

tok, err := oauth.GetAccessToken(t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey)
if err != nil {
return fmt.Errorf("error getting access token: %w", err)
}
t.token = tok

return nil
}

func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(t.dpopKey)
}
Expand Down
31 changes: 27 additions & 4 deletions sdk/internal/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,35 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"golang.org/x/oauth2"
)

const (
tokenExpirationBuffer = 10 * time.Second
)

type ClientCredentials struct {
ClientAuth interface{} // the supported types for this are a JWK (implying jwt-bearer auth) or a string (implying client secret auth)
ClientId string
}

type Token struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in,omitempty"`
Scope string `json:"scope,omitempty"`
received time.Time
}

func (t Token) Expired() bool {
if t.ExpiresIn == 0 {
return false
}

expirationTime := t.received.Add(time.Second * time.Duration(t.ExpiresIn))

return time.Now().After(expirationTime.Add(-tokenExpirationBuffer))
}

func getRequest(tokenEndpoint, dpopNonce string, scopes []string, clientCredentials ClientCredentials, privateJWK *jwk.Key) (*http.Request, error) {
req, err := http.NewRequest("POST", tokenEndpoint, nil)
if err != nil {
Expand Down Expand Up @@ -96,7 +117,7 @@ func getSignedToken(clientID, tokenEndpoint string, key jwk.Key) ([]byte, error)
// this misses the flow where the Authorization server can tell us the next nonce to use.
// missing this flow costs us a bit in efficiency (a round trip per access token) but this is
// still correct because
func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*oauth2.Token, error) {
func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*Token, error) {
req, err := getRequest(tokenEndpoint, "", scopes, clientCredentials, &dpopPrivateKey)
if err != nil {
return nil, err
Expand Down Expand Up @@ -128,7 +149,7 @@ func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials Cli
return processResponse(resp)
}

func processResponse(resp *http.Response) (*oauth2.Token, error) {
func processResponse(resp *http.Response) (*Token, error) {
respBytes, err := io.ReadAll(resp.Body)

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
Expand All @@ -139,11 +160,13 @@ func processResponse(resp *http.Response) (*oauth2.Token, error) {
return nil, fmt.Errorf("error reading bytes from response: %w", err)
}

var token *oauth2.Token
var token *Token
if err := json.Unmarshal(respBytes, &token); err != nil {
return nil, fmt.Errorf("error unmarshaling token from response: %w", err)
}

token.received = time.Now()

return token, nil
}

Expand Down
56 changes: 46 additions & 10 deletions sdk/internal/oauth/oauth_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oauth_test
package oauth

import (
"context"
Expand All @@ -22,7 +22,6 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/sdk/internal/oauth"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tc "github.com/testcontainers/testcontainers-go"
Expand All @@ -49,12 +48,12 @@ func TestGettingAccessTokenFromKeycloak(t *testing.T) {
require.NoError(t, dpopJWK.Set("use", "sig"))
require.NoError(t, dpopJWK.Set("alg", jwa.RS256.String()))

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "testclient",
ClientAuth: "abcd1234",
}

tok, err := oauth.GetAccessToken(
tok, err := GetAccessToken(
idpEndpoint,
[]string{"testscope"},
clientCredentials,
Expand Down Expand Up @@ -83,6 +82,15 @@ func TestGettingAccessTokenFromKeycloak(t *testing.T) {
} else {
t.Fatal("no cnf claim in token")
}

if tok.ExpiresIn < 0 {
t.Fatalf("invalid expiration is before current time: %v", tok)
}

if tok.Expired() {
t.Fatalf("got a token that is currently expired: %v", tok)
}

}

func TestClientSecretNoNonce(t *testing.T) {
Expand Down Expand Up @@ -119,11 +127,11 @@ func TestClientSecretNoNonce(t *testing.T) {
}))
defer server.Close()

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "theclient",
ClientAuth: "thesecret",
}
_, err = oauth.GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
_, err = GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
require.NoError(t, err, "didn't get a token back from the IdP")
}

Expand Down Expand Up @@ -188,16 +196,44 @@ func TestClientSecretWithNonce(t *testing.T) {
}))
defer server.Close()

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "theclient",
ClientAuth: "thesecret",
}
_, err = oauth.GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
_, err = GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
if err != nil {
t.Errorf("didn't get a token back from the IdP: %v", err)
}
}

func TestTokenExpiration_RespectsLeeway(t *testing.T) {
expiredToken := Token{
received: time.Now().Add(-tokenExpirationBuffer - 10*time.Second),
ExpiresIn: 5,
}
if !expiredToken.Expired() {
t.Fatalf("token should be expired")
}

goodToken := Token{
received: time.Now(),
ExpiresIn: 2 * int64(tokenExpirationBuffer/time.Second),
}

if goodToken.Expired() {
t.Fatalf("token should not be expired")
}

justOverBorderToken := Token{
received: time.Now(),
ExpiresIn: int64(tokenExpirationBuffer/time.Second) - 1,
}

if !justOverBorderToken.Expired() {
t.Fatalf("token should not be expired")
}
}

func TestSignedJWTWithNonce(t *testing.T) {
// Generate RSA Key to use for DPoP
dpopKey, err := rsa.GenerateKey(rand.Reader, 4096)
Expand Down Expand Up @@ -273,14 +309,14 @@ func TestSignedJWTWithNonce(t *testing.T) {
}))
defer server.Close()

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "theclient",
ClientAuth: clientAuthJWK,
}

url = server.URL + "/token"

_, err = oauth.GetAccessToken(url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
_, err = GetAccessToken(url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
if err != nil {
t.Errorf("didn't get a token back from the IdP: %v", err)
}
Expand Down
18 changes: 0 additions & 18 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
kas "github.com/opentdf/platform/protocol/go/kas"
"github.com/opentdf/platform/sdk/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
Expand Down Expand Up @@ -66,22 +64,6 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.
func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) {
response, err := k.makeRewrapRequest(keyAccess, policy)

if err != nil {
switch status.Code(err) { //nolint:exhaustive // we can only handle authentication
case codes.Unauthenticated:
err = k.accessTokenSource.RefreshAccessToken()
if err != nil {
return nil, fmt.Errorf("error refreshing access token: %w", err)
}
response, err = k.makeRewrapRequest(keyAccess, policy)
if err != nil {
return nil, fmt.Errorf("Error making rewrap request: %w", err)
}
default:
return nil, fmt.Errorf("Error making rewrap request: %w", err)
}
}

key, err := k.accessTokenSource.DecryptWithDPoPKey(response.GetEntityWrappedKey())
if err != nil {
return nil, fmt.Errorf("error decrypting payload from KAS: %w", err)
Expand Down