Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions sdk/kas_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/opentdf/platform/sdk/internal/crypto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
)

Expand All @@ -21,8 +20,8 @@ const (
)

type KASClient struct {
accessTokenSource AccessTokenSource
grpcTransportCredentials credentials.TransportCredentials
accessTokenSource AccessTokenSource
dialOptions []grpc.DialOption
}

type AccessToken string
Expand Down Expand Up @@ -53,14 +52,12 @@ func (k *KASClient) makeRewrapRequest(keyAccess KeyAccess, policy string) (*kas.
return nil, err
}

creds := grpc.WithTransportCredentials(k.grpcTransportCredentials)

grpcAddress, err := getGRPCAddress(keyAccess.KasURL)
if err != nil {
return nil, err
}

conn, err := grpc.Dial(grpcAddress, creds)
conn, err := grpc.Dial(grpcAddress, k.dialOptions...)
if err != nil {
return nil, fmt.Errorf("Error connecting to kas: %w", err)
}
Expand Down Expand Up @@ -161,12 +158,11 @@ func (k *KASClient) getRewrapRequest(keyAccess KeyAccess, policy string) (*kas.R

func (k *KASClient) getPublicKey(kasInfo KASInfo) (string, error) {
req := kas.PublicKeyRequest{}
creds := grpc.WithTransportCredentials(k.grpcTransportCredentials)
grpcAddress, err := getGRPCAddress(kasInfo.url)
if err != nil {
return "", err
}
conn, err := grpc.Dial(grpcAddress, creds)
conn, err := grpc.Dial(grpcAddress, k.dialOptions...)
if err != nil {
return "", fmt.Errorf("error connecting to grpc service at %s: %w", kasInfo.url, err)
}
Expand Down
43 changes: 22 additions & 21 deletions sdk/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sdk

import (
"github.com/opentdf/platform/sdk/internal/oauth"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
Expand All @@ -9,25 +10,15 @@ type Option func(*config)

// Internal config struct for building SDK options.
type config struct {
token grpc.DialOption
clientCredentials grpc.DialOption
tls grpc.DialOption
clientCredentials oauth.ClientCredentials
tokenEndpoint string
scopes []string
authConfig *AuthConfig
}

func (c *config) build() []grpc.DialOption {
var opts []grpc.DialOption

if c.clientCredentials != nil {
opts = append(opts, c.clientCredentials)
}

if c.token != nil {
opts = append(opts, c.token)
}

opts = append(opts, c.tls)

return opts
return []grpc.DialOption{c.tls}
}

// WithInsecureConn returns an Option that sets up an http connection.
Expand All @@ -37,16 +28,26 @@ func WithInsecureConn() Option {
}
}

// WithToken returns an Option that sets up authentication with a access token.
func WithToken(token string) Option {
// WithClientCredentials returns an Option that sets up authentication with client credentials.
func WithClientCredentials(clientID, clientSecret string, scopes []string) Option {
return func(c *config) {
c.token = grpc.WithPerRPCCredentials(nil)
c.clientCredentials = oauth.ClientCredentials{ClientId: clientID, ClientAuth: clientSecret}
c.scopes = scopes
}
}

// WithClientCredentials returns an Option that sets up authentication with client credentials.
func WithClientCredentials(clientID, clientSecret string) Option {
// When we implement service discovery using a .well-known endpoint this option may become deprecated
func WithTokenEndpoint(tokenEndpoint string) Option {
return func(c *config) {
c.tokenEndpoint = tokenEndpoint
}
}

// temporary option to allow the for token exchange and the
// use of REST-ful KASs. this will likely change as we
// make these options more robust
func WithAuthConfig(authConfig AuthConfig) Option {
return func(c *config) {
c.clientCredentials = grpc.WithPerRPCCredentials(nil)
c.authConfig = &authConfig
}
}
53 changes: 52 additions & 1 deletion sdk/sdk.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package sdk

import (
"crypto/tls"
"errors"
"fmt"

"github.com/opentdf/platform/protocol/go/authorization"
"github.com/opentdf/platform/protocol/go/kasregistry"
Expand All @@ -26,6 +28,7 @@ func (c Error) Error() string {

type SDK struct {
conn *grpc.ClientConn
unwrapper Unwrapper
Namespaces namespaces.NamespaceServiceClient
Attributes attributes.AttributesServiceClient
ResourceMapping resourcemapping.ResourceMappingServiceClient
Expand All @@ -35,23 +38,39 @@ type SDK struct {
}

func New(platformEndpoint string, opts ...Option) (*SDK, error) {
tlsConfig := tls.Config{
MinVersion: tls.VersionTLS12,
}

// Set default options
cfg := &config{
tls: grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")),
tls: grpc.WithTransportCredentials(credentials.NewTLS(&tlsConfig)),
}

// Apply options
for _, opt := range opts {
opt(cfg)
}

var unwrapper Unwrapper
if cfg.authConfig == nil {
uw, err := buildKASClient(cfg)
if err != nil {
return nil, err
}
unwrapper = &uw
} else {
unwrapper = cfg.authConfig
}

conn, err := grpc.Dial(platformEndpoint, cfg.build()...)
if err != nil {
return nil, errors.Join(ErrGrpcDialFailed, err)
}

return &SDK{
conn: conn,
unwrapper: unwrapper,
Attributes: attributes.NewAttributesServiceClient(conn),
Namespaces: namespaces.NewNamespaceServiceClient(conn),
ResourceMapping: resourcemapping.NewResourceMappingServiceClient(conn),
Expand All @@ -61,6 +80,38 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {
}, nil
}

func buildKASClient(c *config) (KASClient, error) {
if (c.clientCredentials.ClientId == "") != (c.clientCredentials.ClientAuth == nil) {
return KASClient{},
errors.New("if specifying client credentials must specify both client id and authentication secret")
}
if (c.clientCredentials.ClientId == "") != (c.tokenEndpoint == "") {
return KASClient{}, errors.New("either both or neither of client credentials and token endpoint must be specified")
}

// at this point we have either both client credentials and a token endpoint or none of the above
if c.clientCredentials.ClientId == "" {
return KASClient{}, errors.New("cannot create an SDK with no client credentials")
}

ts, err := NewIDPAccessTokenSource(
c.clientCredentials,
c.tokenEndpoint,
c.scopes,
)

if err != nil {
return KASClient{}, fmt.Errorf("error configuring IDP access: %w", err)
}

kasClient := KASClient{
accessTokenSource: &ts,
dialOptions: c.build(),
}

return kasClient, nil
}

// Close closes the underlying grpc.ClientConn.
func (s SDK) Close() error {
if err := s.conn.Close(); err != nil {
Expand Down
15 changes: 12 additions & 3 deletions sdk/sdk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ func GetMethods(i interface{}) (m []string) {

func Test_ShouldCreateNewSDK(t *testing.T) {
// When
sdk, err := sdk.New(goodPlatformEndpoint)
sdk, err := sdk.New(goodPlatformEndpoint,
sdk.WithClientCredentials("myid", "mysecret", nil),
sdk.WithTokenEndpoint("https://example.org/token"),
)
// Then
if err != nil {
t.Errorf("Expected no error, got %v", err)
Expand All @@ -53,7 +56,10 @@ func Test_ShouldCreateNewSDK(t *testing.T) {
func Test_ShouldCloseSDKConnection(t *testing.T) {
t.Skip("Skipping test since close is broken")
// Given
sdk, err := sdk.New(goodPlatformEndpoint)
sdk, err := sdk.New(goodPlatformEndpoint,
sdk.WithClientCredentials("myid", "mysecret", nil),
sdk.WithTokenEndpoint("https://example.org/token"),
)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
Expand All @@ -66,7 +72,10 @@ func Test_ShouldCloseSDKConnection(t *testing.T) {
}

func Test_ShouldHaveSameMethods(t *testing.T) {
sdk, err := sdk.New(goodPlatformEndpoint)
sdk, err := sdk.New(goodPlatformEndpoint,
sdk.WithClientCredentials("myid", "mysecret", nil),
sdk.WithTokenEndpoint("https://example.org/token"),
)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions sdk/tdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ type Unwrapper interface {
}

// CreateTDF reads plain text from the given reader and saves it to the writer, subject to the given options
func CreateTDF(writer io.Writer, reader io.ReadSeeker, unwrapper Unwrapper, opts ...TDFOption) (*TDFObject, error) { //nolint:funlen, gocognit, lll
func (sdk SDK) CreateTDF(writer io.Writer, reader io.ReadSeeker, opts ...TDFOption) (*TDFObject, error) { //nolint:funlen, gocognit, lll
inputSize, err := reader.Seek(0, io.SeekEnd)
if err != nil {
return nil, fmt.Errorf("readSeeker.Seek failed: %w", err)
Expand All @@ -101,7 +101,7 @@ func CreateTDF(writer io.Writer, reader io.ReadSeeker, unwrapper Unwrapper, opts
return nil, fmt.Errorf("NewTDFConfig failed: %w", err)
}

err = fillInPublicKeys(unwrapper, tdfConfig.kasInfoList)
err = fillInPublicKeys(sdk.unwrapper, tdfConfig.kasInfoList)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -360,7 +360,7 @@ func (t *TDFObject) createPolicyObject(attributes []string) (PolicyObject, error
}

// LoadTDF loads the tdf and prepare for reading the payload from TDF
func LoadTDF(unwrapper Unwrapper, reader io.ReadSeeker) (*Reader, error) {
func (sdk SDK) LoadTDF(reader io.ReadSeeker) (*Reader, error) {
// create tdf reader
tdfReader, err := archive.NewTDFReader(reader)
if err != nil {
Expand All @@ -381,7 +381,7 @@ func LoadTDF(unwrapper Unwrapper, reader io.ReadSeeker) (*Reader, error) {
return &Reader{
tdfReader: tdfReader,
manifest: *manifestObj,
unwrapper: unwrapper,
unwrapper: sdk.unwrapper,
}, nil
}

Expand Down
Loading