Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions internal/auth/authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,37 +77,19 @@ type FakeTokenSource struct {
func (fts *FakeTokenSource) AccessToken() (sdkauth.AccessToken, error) {
return sdkauth.AccessToken(fts.accessToken), nil
}
func (*FakeTokenSource) DecryptWithDPoPKey([]byte) ([]byte, error) {
return nil, nil
}
func (fts *FakeTokenSource) MakeToken(f func(jwk.Key) ([]byte, error)) ([]byte, error) {
if fts.key == nil {
return nil, errors.New("no such key")
}
return f(fts.key)
}
func (*FakeTokenSource) DPoPPublicKeyPEM() string {
return ""
}
func (*FakeTokenSource) RefreshAccessToken() error {
return nil
}

func (fake FakeAccessTokenSource) AccessToken() (sdkauth.AccessToken, error) {
return sdkauth.AccessToken(fake.accessToken), nil
}
func (fake FakeAccessTokenSource) DecryptWithDPoPKey(_ []byte) ([]byte, error) {
return nil, nil
}
func (fake FakeAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(fake.dpopKey)
}
func (fake FakeAccessTokenSource) DPoPPublicKeyPEM() string {
return "this is the PEM"
}
func (fake FakeAccessTokenSource) RefreshAccessToken() error {
return errors.New("can't refresh this one")
}

func must(err error) {
if err != nil {
Expand Down
2 changes: 0 additions & 2 deletions sdk/auth/access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ type AccessTokenSource interface {
AccessToken() (AccessToken, error)
// probably better to use `crypto.AsymDecryption` here than roll our own since this should be
// more closely linked to what happens in KAS in terms of crypto params
DecryptWithDPoPKey(data []byte) ([]byte, error)
MakeToken(func(jwk.Key) ([]byte, error)) ([]byte, error)
DPoPPublicKeyPEM() string
}
10 changes: 0 additions & 10 deletions sdk/auth/token_adding_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,22 +160,12 @@ type FakeTokenSource struct {
func (fts *FakeTokenSource) AccessToken() (AccessToken, error) {
return AccessToken(fts.accessToken), nil
}
func (*FakeTokenSource) DecryptWithDPoPKey([]byte) ([]byte, error) {
return nil, nil
}
func (fts *FakeTokenSource) MakeToken(f func(jwk.Key) ([]byte, error)) ([]byte, error) {
if fts.key == nil {
return nil, errors.New("no such key")
}
return f(fts.key)
}
func (*FakeTokenSource) DPoPPublicKeyPEM() string {
return ""
}
func (*FakeTokenSource) RefreshAccessToken() error {
return nil
}

func runServer(ctx context.Context, //nolint:ireturn // this is pretty concrete
f *FakeAccessServiceServer, oo TokenAddingInterceptor) (kas.AccessServiceClient, func()) {
buffer := 1024 * 1024
Expand Down
8 changes: 0 additions & 8 deletions sdk/idp_access_token_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,6 @@ func (t *IDPAccessTokenSource) AccessToken() (auth.AccessToken, error) {
return auth.AccessToken(t.token.AccessToken), nil
}

func (t *IDPAccessTokenSource) DecryptWithDPoPKey(data []byte) ([]byte, error) {
return t.asymDecryption.Decrypt(data)
}

func (t *IDPAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(t.dpopKey)
}

func (t *IDPAccessTokenSource) DPoPPublicKeyPEM() string {
return t.dpopPEM
}
40 changes: 36 additions & 4 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
kas "github.com/opentdf/platform/protocol/go/kas"
"github.com/opentdf/platform/sdk/auth"
"github.com/opentdf/platform/sdk/internal/crypto"
"google.golang.org/grpc"
)

Expand All @@ -19,8 +20,10 @@ const (
)

type KASClient struct {
accessTokenSource auth.AccessTokenSource
dialOptions []grpc.DialOption
accessTokenSource auth.AccessTokenSource
dialOptions []grpc.DialOption
clientPublicKeyPEM string
asymDecryption crypto.AsymDecryption
}

// once the backend moves over we should use the same type that the golang backend uses here
Expand All @@ -32,6 +35,35 @@ type rewrapRequestBody struct {
SchemaVersion string `json:"schemaVersion,omitempty"`
}

func newKASClient(dialOptions []grpc.DialOption, accessTokenSource auth.AccessTokenSource) (*KASClient, error) {
rsaKeyPair, err := crypto.NewRSAKeyPair(tdf3KeySize)
if err != nil {
return nil, fmt.Errorf("crypto.NewRSAKeyPair failed: %w", err)
}

clientPublicKey, err := rsaKeyPair.PublicKeyInPemFormat()
if err != nil {
return nil, fmt.Errorf("crypto.PublicKeyInPemFormat failed: %w", err)
}

clientPrivateKey, err := rsaKeyPair.PrivateKeyInPemFormat()
if err != nil {
return nil, fmt.Errorf("crypto.PrivateKeyInPemFormat failed: %w", err)
}

asymDecryption, err := crypto.NewAsymDecryption(clientPrivateKey)
if err != nil {
return nil, fmt.Errorf("crypto.NewAsymDecryption failed: %w", err)
}

return &KASClient{
accessTokenSource: accessTokenSource,
dialOptions: dialOptions,
clientPublicKeyPEM: clientPublicKey,
asymDecryption: asymDecryption,
}, nil
}

// there is no connection caching as of now
func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.RewrapResponse, error) {
rewrapRequest, err := k.getRewrapRequest(keyAccess, policy)
Expand Down Expand Up @@ -64,7 +96,7 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.
func (k *KASClient) unwrap(keyAccess KeyAccess, policy string) ([]byte, error) {
response, err := k.makeRewrapRequest(keyAccess, policy)

key, err := k.accessTokenSource.DecryptWithDPoPKey(response.GetEntityWrappedKey())
key, err := k.asymDecryption.Decrypt(response.GetEntityWrappedKey())
if err != nil {
return nil, fmt.Errorf("error decrypting payload from KAS: %w", err)
}
Expand Down Expand Up @@ -92,7 +124,7 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R
requestBody := rewrapRequestBody{
Policy: policy,
KeyAccess: keyAccess,
ClientPublicKey: k.accessTokenSource.DPoPPublicKeyPEM(),
ClientPublicKey: k.clientPublicKeyPEM,
}
requestBodyJSON, err := json.Marshal(requestBody)
if err != nil {
Expand Down
25 changes: 12 additions & 13 deletions sdk/kas_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package sdk

import (
"encoding/json"
"errors"
"testing"

"google.golang.org/grpc"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
Expand All @@ -21,18 +22,9 @@ type FakeAccessTokenSource struct {
func (fake FakeAccessTokenSource) AccessToken() (auth.AccessToken, error) {
return auth.AccessToken(fake.accessToken), nil
}
func (fake FakeAccessTokenSource) DecryptWithDPoPKey(encrypted []byte) ([]byte, error) {
return fake.asymDecryption.Decrypt(encrypted)
}
func (fake FakeAccessTokenSource) MakeToken(tokenMaker func(jwk.Key) ([]byte, error)) ([]byte, error) {
return tokenMaker(fake.dpopKey)
}
func (fake FakeAccessTokenSource) DPoPPublicKeyPEM() string {
return "this is the PEM"
}
func (fake FakeAccessTokenSource) RefreshAccessToken() error {
return errors.New("can't refresh this one")
}

func getTokenSource(t *testing.T) FakeAccessTokenSource {
dpopKey, _ := crypto.NewRSAKeyPair(2048)
Expand All @@ -55,8 +47,13 @@ func getTokenSource(t *testing.T) FakeAccessTokenSource {
}

func TestCreatingRequest(t *testing.T) {
var dialOption []grpc.DialOption
tokenSource := getTokenSource(t)
client := KASClient{accessTokenSource: tokenSource}
client, err := newKASClient(dialOption, tokenSource)
if err != nil {
t.Fatalf("error setting KASClient: %v", err)
}

keyAccess := KeyAccess{
KeyType: "type1",
KasURL: "https://kas.example.org",
Expand Down Expand Up @@ -94,9 +91,11 @@ func TestCreatingRequest(t *testing.T) {
t.Fatalf("error unmarshaling request body: %v", err)
}

if requestBody["clientPublicKey"] != "this is the PEM" {
t.Fatalf("incorrect public key included")
_, err = crypto.NewAsymEncryption(requestBody["clientPublicKey"].(string))
if err != nil {
t.Fatalf("NewAsymEncryption failed, incorrect public key include: %v", err)
}

if requestBody["policy"] != "a policy" {
t.Fatalf("incorrect policy")
}
Expand Down
5 changes: 4 additions & 1 deletion sdk/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {

var unwrapper Unwrapper
if cfg.authConfig == nil {
unwrapper = &KASClient{dialOptions: dialOptions, accessTokenSource: accessTokenSource}
unwrapper, err = newKASClient(dialOptions, accessTokenSource)
if err != nil {
return nil, err
}
} else {
unwrapper = cfg.authConfig
}
Expand Down