diff --git a/internal/auth/authn_test.go b/internal/auth/authn_test.go index 09773e6761..108e81b040 100644 --- a/internal/auth/authn_test.go +++ b/internal/auth/authn_test.go @@ -77,37 +77,19 @@ type FakeTokenSource struct { func (fts *FakeTokenSource) AccessToken() (sdkauth.AccessToken, error) { return sdkauth.AccessToken(fts.accessToken), nil } -func (*FakeTokenSource) DecryptWithDPoPKey([]byte) ([]byte, error) { - return nil, nil -} func (fts *FakeTokenSource) MakeToken(f func(jwk.Key) ([]byte, error)) ([]byte, error) { if fts.key == nil { return nil, errors.New("no such key") } return f(fts.key) } -func (*FakeTokenSource) DPoPPublicKeyPEM() string { - return "" -} -func (*FakeTokenSource) RefreshAccessToken() error { - return nil -} func (fake FakeAccessTokenSource) AccessToken() (sdkauth.AccessToken, error) { return sdkauth.AccessToken(fake.accessToken), nil } -func (fake FakeAccessTokenSource) DecryptWithDPoPKey(_ []byte) ([]byte, error) { - return nil, nil -} func (fake FakeAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return tokenMaker(fake.dpopKey) } -func (fake FakeAccessTokenSource) DPoPPublicKeyPEM() string { - return "this is the PEM" -} -func (fake FakeAccessTokenSource) RefreshAccessToken() error { - return errors.New("can't refresh this one") -} func must(err error) { if err != nil { diff --git a/sdk/auth/access_token_source.go b/sdk/auth/access_token_source.go index b685d0bdcc..dde0a259a3 100644 --- a/sdk/auth/access_token_source.go +++ b/sdk/auth/access_token_source.go @@ -8,7 +8,5 @@ type AccessTokenSource interface { AccessToken() (AccessToken, error) // probably better to use `crypto.AsymDecryption` here than roll our own since this should be // more closely linked to what happens in KAS in terms of crypto params - DecryptWithDPoPKey(data []byte) ([]byte, error) MakeToken(func(jwk.Key) ([]byte, error)) ([]byte, error) - DPoPPublicKeyPEM() string } diff --git a/sdk/auth/token_adding_interceptor_test.go b/sdk/auth/token_adding_interceptor_test.go index c5d2420665..d959dc242b 100644 --- a/sdk/auth/token_adding_interceptor_test.go +++ b/sdk/auth/token_adding_interceptor_test.go @@ -160,22 +160,12 @@ type FakeTokenSource struct { func (fts *FakeTokenSource) AccessToken() (AccessToken, error) { return AccessToken(fts.accessToken), nil } -func (*FakeTokenSource) DecryptWithDPoPKey([]byte) ([]byte, error) { - return nil, nil -} func (fts *FakeTokenSource) MakeToken(f func(jwk.Key) ([]byte, error)) ([]byte, error) { if fts.key == nil { return nil, errors.New("no such key") } return f(fts.key) } -func (*FakeTokenSource) DPoPPublicKeyPEM() string { - return "" -} -func (*FakeTokenSource) RefreshAccessToken() error { - return nil -} - func runServer(ctx context.Context, //nolint:ireturn // this is pretty concrete f *FakeAccessServiceServer, oo TokenAddingInterceptor) (kas.AccessServiceClient, func()) { buffer := 1024 * 1024 diff --git a/sdk/idp_access_token_source.go b/sdk/idp_access_token_source.go index 3d51fdb07d..7f86a0447f 100644 --- a/sdk/idp_access_token_source.go +++ b/sdk/idp_access_token_source.go @@ -132,14 +132,6 @@ func (t *IDPAccessTokenSource) AccessToken() (auth.AccessToken, error) { return auth.AccessToken(t.token.AccessToken), nil } -func (t *IDPAccessTokenSource) DecryptWithDPoPKey(data []byte) ([]byte, error) { - return t.asymDecryption.Decrypt(data) -} - func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return tokenMaker(t.dpopKey) } - -func (t *IDPAccessTokenSource) DPoPPublicKeyPEM() string { - return t.dpopPEM -} diff --git a/sdk/kas_client.go b/sdk/kas_client.go index a87eb01e63..602c0e0040 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -11,6 +11,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" kas "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/sdk/auth" + "github.com/opentdf/platform/sdk/internal/crypto" "google.golang.org/grpc" ) @@ -19,8 +20,10 @@ const ( ) type KASClient struct { - accessTokenSource auth.AccessTokenSource - dialOptions []grpc.DialOption + accessTokenSource auth.AccessTokenSource + dialOptions []grpc.DialOption + clientPublicKeyPEM string + asymDecryption crypto.AsymDecryption } // once the backend moves over we should use the same type that the golang backend uses here @@ -32,6 +35,35 @@ type rewrapRequestBody struct { SchemaVersion string `json:"schemaVersion,omitempty"` } +func newKASClient(dialOptions []grpc.DialOption, accessTokenSource auth.AccessTokenSource) (*KASClient, error) { + rsaKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize) + if err != nil { + return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err) + } + + clientPublicKey, err := rsaKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err) + } + + clientPrivateKey, err := rsaKeyPair.PrivateKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("crypto.PrivateKeyInPemFormat failed: %w", err) + } + + asymDecryption, err := crypto.NewAsymDecryption(clientPrivateKey) + if err != nil { + return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err) + } + + return &KASClient{ + accessTokenSource: accessTokenSource, + dialOptions: dialOptions, + clientPublicKeyPEM: clientPublicKey, + asymDecryption: asymDecryption, + }, nil +} + // there is no connection caching as of now func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.RewrapResponse, error) { rewrapRequest, err := k.getRewrapRequest(keyAccess, policy) @@ -64,7 +96,7 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas. func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) { response, err := k.makeRewrapRequest(keyAccess, policy) - key, err := k.accessTokenSource.DecryptWithDPoPKey(response.GetEntityWrappedKey()) + key, err := k.asymDecryption.Decrypt(response.GetEntityWrappedKey()) if err != nil { return nil, fmt.Errorf("error decrypting payload from KAS: %w", err) } @@ -92,7 +124,7 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R requestBody := rewrapRequestBody{ Policy: policy, KeyAccess: keyAccess, - ClientPublicKey: k.accessTokenSource.DPoPPublicKeyPEM(), + ClientPublicKey: k.clientPublicKeyPEM, } requestBodyJSON, err := json.Marshal(requestBody) if err != nil { diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 70ee748daf..dae8576797 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -2,9 +2,10 @@ package sdk import ( "encoding/json" - "errors" "testing" + "google.golang.org/grpc" + "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" @@ -21,18 +22,9 @@ type FakeAccessTokenSource struct { func (fake FakeAccessTokenSource) AccessToken() (auth.AccessToken, error) { return auth.AccessToken(fake.accessToken), nil } -func (fake FakeAccessTokenSource) DecryptWithDPoPKey(encrypted []byte) ([]byte, error) { - return fake.asymDecryption.Decrypt(encrypted) -} func (fake FakeAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) { return tokenMaker(fake.dpopKey) } -func (fake FakeAccessTokenSource) DPoPPublicKeyPEM() string { - return "this is the PEM" -} -func (fake FakeAccessTokenSource) RefreshAccessToken() error { - return errors.New("can't refresh this one") -} func getTokenSource(t *testing.T) FakeAccessTokenSource { dpopKey, _ := crypto.NewRSAKeyPair(2048) @@ -55,8 +47,13 @@ func getTokenSource(t *testing.T) FakeAccessTokenSource { } func TestCreatingRequest(t *testing.T) { + var dialOption []grpc.DialOption tokenSource := getTokenSource(t) - client := KASClient{accessTokenSource: tokenSource} + client, err := newKASClient(dialOption, tokenSource) + if err != nil { + t.Fatalf("error setting KASClient: %v", err) + } + keyAccess := KeyAccess{ KeyType: "type1", KasURL: "https://kas.example.org", @@ -94,9 +91,11 @@ func TestCreatingRequest(t *testing.T) { t.Fatalf("error unmarshaling request body: %v", err) } - if requestBody["clientPublicKey"] != "this is the PEM" { - t.Fatalf("incorrect public key included") + _, err = crypto.NewAsymEncryption(requestBody["clientPublicKey"].(string)) + if err != nil { + t.Fatalf("NewAsymEncryption failed, incorrect public key include: %v", err) } + if requestBody["policy"] != "a policy" { t.Fatalf("incorrect policy") } diff --git a/sdk/sdk.go b/sdk/sdk.go index ce5d05d8ab..1e1977164c 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -66,7 +66,10 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { var unwrapper Unwrapper if cfg.authConfig == nil { - unwrapper = &KASClient{dialOptions: dialOptions, accessTokenSource: accessTokenSource} + unwrapper, err = newKASClient(dialOptions, accessTokenSource) + if err != nil { + return nil, err + } } else { unwrapper = cfg.authConfig }