Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sdk/auth/access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ type AccessTokenSource interface {
DecryptWithDPoPKey(data []byte) ([]byte, error)
MakeToken(func(jwk.Key) ([]byte, error)) ([]byte, error)
DPoPPublicKeyPEM() string
RefreshAccessToken() error
}
28 changes: 10 additions & 18 deletions sdk/idp_access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"log/slog"
"net/url"
"strings"
"sync"
Expand All @@ -15,7 +16,6 @@ import (
"github.com/opentdf/platform/sdk/auth"
"github.com/opentdf/platform/sdk/internal/crypto"
"github.com/opentdf/platform/sdk/internal/oauth"
"golang.org/x/oauth2"
)

const (
Expand Down Expand Up @@ -81,7 +81,7 @@ to a DPoP key
type IDPAccessTokenSource struct {
credentials oauth.ClientCredentials
idpTokenEndpoint url.URL
token *oauth2.Token
token *oauth.Token
scopes []string
dpopKey jwk.Key
asymDecryption crypto.AsymDecryption
Expand Down Expand Up @@ -117,11 +117,16 @@ func NewIDPAccessTokenSource(

// use a pointer receiver so that the token state is shared
func (t *IDPAccessTokenSource) AccessToken() (auth.AccessToken, error) {
if t.token == nil {
err := t.RefreshAccessToken()
t.tokenMutex.Lock()
defer t.tokenMutex.Unlock()

if t.token == nil || t.token.Expired() {
slog.Debug("getting new access token")
tok, err := oauth.GetAccessToken(t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey)
if err != nil {
return auth.AccessToken(""), err
return "", fmt.Errorf("error getting access token: %w", err)
}
t.token = tok
}

return auth.AccessToken(t.token.AccessToken), nil
Expand All @@ -131,19 +136,6 @@ func (t *IDPAccessTokenSource) DecryptWithDPoPKey(data []byte) ([]byte, error) {
return t.asymDecryption.Decrypt(data)
}

func (t *IDPAccessTokenSource) RefreshAccessToken() error {
t.tokenMutex.Lock()
defer t.tokenMutex.Unlock()

tok, err := oauth.GetAccessToken(t.idpTokenEndpoint.String(), t.scopes, t.credentials, t.dpopKey)
if err != nil {
return fmt.Errorf("error getting access token: %w", err)
}
t.token = tok

return nil
}

func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(t.dpopKey)
}
Expand Down
41 changes: 35 additions & 6 deletions sdk/internal/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,37 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"golang.org/x/oauth2"
)

const (
tokenExpirationBuffer = 10 * time.Second
)

type ClientCredentials struct {
ClientAuth interface{} // the supported types for this are a JWK (implying jwt-bearer auth) or a string (implying client secret auth)
ClientId string
}

type tokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in,omitempty"`
Scope string `json:"scope,omitempty"`
}

type Token struct {
AccessToken string
expiry time.Time
}

func (t Token) Expired() bool {
if t.expiry.IsZero() {
return false
}

return time.Now().After(t.expiry.Add(-tokenExpirationBuffer))
}

func getRequest(tokenEndpoint, dpopNonce string, scopes []string, clientCredentials ClientCredentials, privateJWK *jwk.Key) (*http.Request, error) {
req, err := http.NewRequest("POST", tokenEndpoint, nil)
if err != nil {
Expand Down Expand Up @@ -96,7 +119,7 @@ func getSignedToken(clientID, tokenEndpoint string, key jwk.Key) ([]byte, error)
// this misses the flow where the Authorization server can tell us the next nonce to use.
// missing this flow costs us a bit in efficiency (a round trip per access token) but this is
// still correct because
func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*oauth2.Token, error) {
func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials ClientCredentials, dpopPrivateKey jwk.Key) (*Token, error) {
req, err := getRequest(tokenEndpoint, "", scopes, clientCredentials, &dpopPrivateKey)
if err != nil {
return nil, err
Expand Down Expand Up @@ -128,7 +151,7 @@ func GetAccessToken(tokenEndpoint string, scopes []string, clientCredentials Cli
return processResponse(resp)
}

func processResponse(resp *http.Response) (*oauth2.Token, error) {
func processResponse(resp *http.Response) (*Token, error) {
respBytes, err := io.ReadAll(resp.Body)

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
Expand All @@ -139,12 +162,18 @@ func processResponse(resp *http.Response) (*oauth2.Token, error) {
return nil, fmt.Errorf("error reading bytes from response: %w", err)
}

var token *oauth2.Token
if err := json.Unmarshal(respBytes, &token); err != nil {
var tokenResponse *tokenResponse
if err := json.Unmarshal(respBytes, &tokenResponse); err != nil {
return nil, fmt.Errorf("error unmarshaling token from response: %w", err)
}

return token, nil
var token Token
if tokenResponse.ExpiresIn != 0 {
token.expiry = time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second)
}
token.AccessToken = tokenResponse.AccessToken

return &token, nil
}

func getDPoPAssertion(dpopJWK jwk.Key, method string, endpoint string, nonce string) (string, error) {
Expand Down
38 changes: 28 additions & 10 deletions sdk/internal/oauth/oauth_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oauth_test
package oauth

import (
"context"
Expand All @@ -22,7 +22,6 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/sdk/internal/oauth"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tc "github.com/testcontainers/testcontainers-go"
Expand All @@ -49,12 +48,12 @@ func TestGettingAccessTokenFromKeycloak(t *testing.T) {
require.NoError(t, dpopJWK.Set("use", "sig"))
require.NoError(t, dpopJWK.Set("alg", jwa.RS256.String()))

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "testclient",
ClientAuth: "abcd1234",
}

tok, err := oauth.GetAccessToken(
tok, err := GetAccessToken(
idpEndpoint,
[]string{"testscope"},
clientCredentials,
Expand Down Expand Up @@ -83,6 +82,15 @@ func TestGettingAccessTokenFromKeycloak(t *testing.T) {
} else {
t.Fatal("no cnf claim in token")
}

if tok.expiry.Before(time.Now()) {
t.Fatalf("invalid expiration is before current time: %v", tok.expiry)
}

if tok.Expired() {
t.Fatalf("got a token that is currently expired: %v", tok.expiry)
}

}

func TestClientSecretNoNonce(t *testing.T) {
Expand Down Expand Up @@ -119,11 +127,11 @@ func TestClientSecretNoNonce(t *testing.T) {
}))
defer server.Close()

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "theclient",
ClientAuth: "thesecret",
}
_, err = oauth.GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
_, err = GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
require.NoError(t, err, "didn't get a token back from the IdP")
}

Expand Down Expand Up @@ -188,16 +196,26 @@ func TestClientSecretWithNonce(t *testing.T) {
}))
defer server.Close()

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "theclient",
ClientAuth: "thesecret",
}
_, err = oauth.GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
_, err = GetAccessToken(server.URL+"/token", []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
if err != nil {
t.Errorf("didn't get a token back from the IdP: %v", err)
}
}

func TestTokenExpiration_RespectsLeeway(t *testing.T) {
if !(Token{expiry: time.Now().Add(-tokenExpirationBuffer - 10*time.Second)}).Expired() {
t.Fatalf("token should be expired")
}

if (Token{expiry: time.Now().Add(tokenExpirationBuffer + 10*time.Second)}).Expired() {
t.Fatalf("token should not be expired")
}
}

func TestSignedJWTWithNonce(t *testing.T) {
// Generate RSA Key to use for DPoP
dpopKey, err := rsa.GenerateKey(rand.Reader, 4096)
Expand Down Expand Up @@ -273,14 +291,14 @@ func TestSignedJWTWithNonce(t *testing.T) {
}))
defer server.Close()

clientCredentials := oauth.ClientCredentials{
clientCredentials := ClientCredentials{
ClientId: "theclient",
ClientAuth: clientAuthJWK,
}

url = server.URL + "/token"

_, err = oauth.GetAccessToken(url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
_, err = GetAccessToken(url, []string{"scope1", "scope2"}, clientCredentials, dpopJWK)
if err != nil {
t.Errorf("didn't get a token back from the IdP: %v", err)
}
Expand Down
18 changes: 0 additions & 18 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
kas "github.com/opentdf/platform/protocol/go/kas"
"github.com/opentdf/platform/sdk/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const (
Expand Down Expand Up @@ -66,22 +64,6 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.
func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) {
response, err := k.makeRewrapRequest(keyAccess, policy)

if err != nil {
switch status.Code(err) { //nolint:exhaustive // we can only handle authentication
case codes.Unauthenticated:
err = k.accessTokenSource.RefreshAccessToken()
if err != nil {
return nil, fmt.Errorf("error refreshing access token: %w", err)
}
response, err = k.makeRewrapRequest(keyAccess, policy)
if err != nil {
return nil, fmt.Errorf("Error making rewrap request: %w", err)
}
default:
return nil, fmt.Errorf("Error making rewrap request: %w", err)
}
}

key, err := k.accessTokenSource.DecryptWithDPoPKey(response.GetEntityWrappedKey())
if err != nil {
return nil, fmt.Errorf("error decrypting payload from KAS: %w", err)
Expand Down