Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions sdk/idp_access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/sdk/internal/crypto"
"github.com/opentdf/platform/sdk/internal/oauth"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -116,7 +115,7 @@ func NewIDPAccessTokenSource(
}

// use a pointer receiver so that the token state is shared
func (t *IDPAccessTokenSource) GetAccessToken() (AccessToken, error) {
func (t *IDPAccessTokenSource) AccessToken() (AccessToken, error) {
if t.token == nil {
err := t.RefreshAccessToken()
if err != nil {
Expand All @@ -127,7 +126,7 @@ func (t *IDPAccessTokenSource) GetAccessToken() (AccessToken, error) {
return AccessToken(t.token.AccessToken), nil
}

func (t *IDPAccessTokenSource) GetAsymDecryption() crypto.AsymDecryption {
func (t *IDPAccessTokenSource) AsymDecryption() crypto.AsymDecryption {
return t.asymDecryption
}

Expand All @@ -144,15 +143,10 @@ func (t *IDPAccessTokenSource) RefreshAccessToken() error {
return nil
}

func (t *IDPAccessTokenSource) SignToken(tok jwt.Token) ([]byte, error) {
signed, err := jwt.Sign(tok, jwt.WithKey(t.dpopKey.Algorithm(), t.dpopKey))
if err != nil {
return nil, fmt.Errorf("error signing DPOP token: %w", err)
}

return signed, nil
func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(t.dpopKey)
}

func (t *IDPAccessTokenSource) GetDPoPPublicKeyPEM() string {
func (t *IDPAccessTokenSource) DPOPPublicKeyPEM() string {
return t.dpopPEM
}
25 changes: 17 additions & 8 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/url"
"time"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
kas "github.com/opentdf/backend-go/pkg/access"
"github.com/opentdf/platform/sdk/internal/crypto"
Expand All @@ -27,12 +28,12 @@ type KASClient struct {
type AccessToken string

type AccessTokenSource interface {
GetAccessToken() (AccessToken, error)
AccessToken() (AccessToken, error)
// probably better to use `crypto.AsymDecryption` here than roll our own since this should be
// more closely linked to what happens in KAS in terms of crypto params
GetAsymDecryption() crypto.AsymDecryption
SignToken(jwt.Token) ([]byte, error)
GetDPoPPublicKeyPEM() string
AsymDecryption() crypto.AsymDecryption
MakeToken(func(jwk.Key) ([]byte, error)) ([]byte, error)
DPOPPublicKeyPEM() string
RefreshAccessToken() error
}

Expand Down Expand Up @@ -93,7 +94,7 @@ func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) {
}
}

key, err := k.accessTokenSource.GetAsymDecryption().Decrypt(response.EntityWrappedKey)
key, err := k.accessTokenSource.AsymDecryption().Decrypt(response.EntityWrappedKey)
if err != nil {
return nil, fmt.Errorf("error decrypting payload from KAS: %w", err)
}
Expand Down Expand Up @@ -121,7 +122,7 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R
requestBody := rewrapRequestBody{
Policy: policy,
KeyAccess: keyAccess,
ClientPublicKey: k.accessTokenSource.GetDPoPPublicKeyPEM(),
ClientPublicKey: k.accessTokenSource.DPOPPublicKeyPEM(),
}

requestBodyJSON, err := json.Marshal(requestBody)
Expand All @@ -139,12 +140,20 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R
return nil, fmt.Errorf("failed to create jwt: %w", err)
}

signedToken, err := k.accessTokenSource.SignToken(tok)
signedToken, err := k.accessTokenSource.MakeToken(func(key jwk.Key) ([]byte, error) {
signed, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key))
if err != nil {
return nil, fmt.Errorf("error signing DPOP token: %w", err)
}

return signed, nil
})

if err != nil {
return nil, fmt.Errorf("failed to sign the token: %w", err)
}

accessToken, err := k.accessTokenSource.GetAccessToken()
accessToken, err := k.accessTokenSource.AccessToken()
if err != nil {
return nil, fmt.Errorf("error getting access token: %w", err)
}
Expand Down
15 changes: 5 additions & 10 deletions sdk/kas_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package sdk
import (
"encoding/json"
"errors"
"fmt"
"testing"

"github.com/lestrrat-go/jwx/v2/jwa"
Expand All @@ -18,20 +17,16 @@ type FakeAccessTokenSource struct {
accessToken string
}

func (fake FakeAccessTokenSource) GetAccessToken() (AccessToken, error) {
func (fake FakeAccessTokenSource) AccessToken() (AccessToken, error) {
return AccessToken(fake.accessToken), nil
}
func (fake FakeAccessTokenSource) GetAsymDecryption() crypto.AsymDecryption {
func (fake FakeAccessTokenSource) AsymDecryption() crypto.AsymDecryption {
return fake.asymDecryption
}
func (fake FakeAccessTokenSource) SignToken(tok jwt.Token) ([]byte, error) {
signed, err := jwt.Sign(tok, jwt.WithKey(fake.dPOPKey.Algorithm(), fake.dPOPKey))
if err != nil {
return nil, fmt.Errorf("error signing DPOP token: %w", err)
}
return signed, nil
func (fake FakeAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(fake.dPOPKey)
}
func (fake FakeAccessTokenSource) GetDPoPPublicKeyPEM() string {
func (fake FakeAccessTokenSource) DPOPPublicKeyPEM() string {
return "this is the PEM"
}
func (fake FakeAccessTokenSource) RefreshAccessToken() error {
Expand Down