diff --git a/sdk/idp_access_token_source.go b/sdk/idp_access_token_source.go index 14b7357353..e56ec62220 100644 --- a/sdk/idp_access_token_source.go +++ b/sdk/idp_access_token_source.go @@ -12,7 +12,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/sdk/internal/crypto" "github.com/opentdf/platform/sdk/internal/oauth" "golang.org/x/oauth2" @@ -116,7 +115,7 @@ func NewIDPAccessTokenSource( } // use a pointer receiver so that the token state is shared -func (t *IDPAccessTokenSource) GetAccessToken() (AccessToken, error) { +func (t *IDPAccessTokenSource) AccessToken() (AccessToken, error) { if t.token == nil { err := t.RefreshAccessToken() if err != nil { @@ -127,7 +126,7 @@ func (t *IDPAccessTokenSource) GetAccessToken() (AccessToken, error) { return AccessToken(t.token.AccessToken), nil } -func (t *IDPAccessTokenSource) GetAsymDecryption() crypto.AsymDecryption { +func (t *IDPAccessTokenSource) AsymDecryption() crypto.AsymDecryption { return t.asymDecryption } @@ -144,15 +143,10 @@ func (t *IDPAccessTokenSource) RefreshAccessToken() error { return nil } -func (t *IDPAccessTokenSource) SignToken(tok jwt.Token) ([]byte, error) { - signed, err := jwt.Sign(tok, jwt.WithKey(t.dpopKey.Algorithm(), t.dpopKey)) - if err != nil { - return nil, fmt.Errorf("error signing DPOP token: %w", err) - } - - return signed, nil +func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { + return tokenMaker(t.dpopKey) } -func (t *IDPAccessTokenSource) GetDPoPPublicKeyPEM() string { +func (t *IDPAccessTokenSource) DPOPPublicKeyPEM() string { return t.dpopPEM } diff --git a/sdk/kas_client.go b/sdk/kas_client.go index bc7d2abe0a..999384cf9a 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -7,6 +7,7 @@ import ( "net/url" "time" + "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" kas "github.com/opentdf/backend-go/pkg/access" "github.com/opentdf/platform/sdk/internal/crypto" @@ -27,12 +28,12 @@ type KASClient struct { type AccessToken string type AccessTokenSource interface { - GetAccessToken() (AccessToken, error) + AccessToken() (AccessToken, error) // probably better to use `crypto.AsymDecryption` here than roll our own since this should be // more closely linked to what happens in KAS in terms of crypto params - GetAsymDecryption() crypto.AsymDecryption - SignToken(jwt.Token) ([]byte, error) - GetDPoPPublicKeyPEM() string + AsymDecryption() crypto.AsymDecryption + MakeToken(func(jwk.Key) ([]byte, error)) ([]byte, error) + DPOPPublicKeyPEM() string RefreshAccessToken() error } @@ -93,7 +94,7 @@ func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) { } } - key, err := k.accessTokenSource.GetAsymDecryption().Decrypt(response.EntityWrappedKey) + key, err := k.accessTokenSource.AsymDecryption().Decrypt(response.EntityWrappedKey) if err != nil { return nil, fmt.Errorf("error decrypting payload from KAS: %w", err) } @@ -121,7 +122,7 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R requestBody := rewrapRequestBody{ Policy: policy, KeyAccess: keyAccess, - ClientPublicKey: k.accessTokenSource.GetDPoPPublicKeyPEM(), + ClientPublicKey: k.accessTokenSource.DPOPPublicKeyPEM(), } requestBodyJSON, err := json.Marshal(requestBody) @@ -139,12 +140,20 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R return nil, fmt.Errorf("failed to create jwt: %w", err) } - signedToken, err := k.accessTokenSource.SignToken(tok) + signedToken, err := k.accessTokenSource.MakeToken(func(key jwk.Key) ([]byte, error) { + signed, err := jwt.Sign(tok, jwt.WithKey(key.Algorithm(), key)) + if err != nil { + return nil, fmt.Errorf("error signing DPOP token: %w", err) + } + + return signed, nil + }) + if err != nil { return nil, fmt.Errorf("failed to sign the token: %w", err) } - accessToken, err := k.accessTokenSource.GetAccessToken() + accessToken, err := k.accessTokenSource.AccessToken() if err != nil { return nil, fmt.Errorf("error getting access token: %w", err) } diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 94ead76681..3803b45e80 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -3,7 +3,6 @@ package sdk import ( "encoding/json" "errors" - "fmt" "testing" "github.com/lestrrat-go/jwx/v2/jwa" @@ -18,20 +17,16 @@ type FakeAccessTokenSource struct { accessToken string } -func (fake FakeAccessTokenSource) GetAccessToken() (AccessToken, error) { +func (fake FakeAccessTokenSource) AccessToken() (AccessToken, error) { return AccessToken(fake.accessToken), nil } -func (fake FakeAccessTokenSource) GetAsymDecryption() crypto.AsymDecryption { +func (fake FakeAccessTokenSource) AsymDecryption() crypto.AsymDecryption { return fake.asymDecryption } -func (fake FakeAccessTokenSource) SignToken(tok jwt.Token) ([]byte, error) { - signed, err := jwt.Sign(tok, jwt.WithKey(fake.dPOPKey.Algorithm(), fake.dPOPKey)) - if err != nil { - return nil, fmt.Errorf("error signing DPOP token: %w", err) - } - return signed, nil +func (fake FakeAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { + return tokenMaker(fake.dPOPKey) } -func (fake FakeAccessTokenSource) GetDPoPPublicKeyPEM() string { +func (fake FakeAccessTokenSource) DPOPPublicKeyPEM() string { return "this is the PEM" } func (fake FakeAccessTokenSource) RefreshAccessToken() error {