Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions sdk/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@ type config struct {
tokenExchange *oauth.TokenExchangeInfo
tokenEndpoint string
scopes []string
policyConn *grpc.ClientConn
authorizationConn *grpc.ClientConn
entityresolutionConn *grpc.ClientConn
extraDialOptions []grpc.DialOption
certExchange *oauth.CertExchangeInfo
wellknownConn *grpc.ClientConn
platformConfiguration PlatformConfiguration
kasSessionKey *ocrypto.RsaKeyPair
dpopKey *ocrypto.RsaKeyPair
ipc bool
tdfFeatures tdfFeatures
customAccessTokenSource auth.AccessTokenSource
coreConn *grpc.ClientConn
}

// Options specific to TDF protocol features
Expand Down Expand Up @@ -99,21 +96,24 @@ func withCustomAccessTokenSource(a auth.AccessTokenSource) Option {
}
}

// Deprecated: Use WithCustomCoreConnection instead
func WithCustomPolicyConnection(conn *grpc.ClientConn) Option {
return func(c *config) {
c.policyConn = conn
c.coreConn = conn
}
}

// Deprecated: Use WithCustomCoreConnection instead
func WithCustomAuthorizationConnection(conn *grpc.ClientConn) Option {
return func(c *config) {
c.authorizationConn = conn
c.coreConn = conn
}
}

// Deprecated: Use WithCustomCoreConnection instead
func WithCustomEntityResolutionConnection(conn *grpc.ClientConn) Option {
return func(c *config) {
c.entityresolutionConn = conn
c.coreConn = conn
}
}

Expand Down Expand Up @@ -156,7 +156,7 @@ func WithSessionSignerRSA(key *rsa.PrivateKey) Option {

func WithCustomWellknownConnection(conn *grpc.ClientConn) Option {
return func(c *config) {
c.wellknownConn = conn
c.coreConn = conn
}
}

Expand All @@ -183,3 +183,10 @@ func WithNoKIDInKAO() Option {
c.tdfFeatures.noKID = true
}
}

// WithCoreConnection returns an Option that sets up a connection to the core platform
func WithCustomCoreConnection(conn *grpc.ClientConn) Option {
return func(c *config) {
c.coreConn = conn
}
}
8 changes: 8 additions & 0 deletions sdk/platformconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ package sdk
import "log/slog"

func (s SDK) PlatformIssuer() string {
// This check is needed if we want to fetch platform configuration over ipc
if s.config.platformConfiguration == nil {
cfg, err := getPlatformConfiguration(s.conn)
if err != nil {
slog.Warn("failed to get platform configuration", slog.Any("error", err))
}
s.config.platformConfiguration = cfg
}
value, ok := s.config.platformConfiguration["platform_issuer"].(string)
if !ok {
slog.Warn("platform_issuer not found in platform configuration")
Expand Down
81 changes: 41 additions & 40 deletions sdk/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ type SDK struct {

func New(platformEndpoint string, opts ...Option) (*SDK, error) {
var (
defaultConn *grpc.ClientConn // Connection to the platform if no other connection is provided
err error
platformConn *grpc.ClientConn // Connection to the platform
err error
)

// Set default options
Expand All @@ -82,13 +82,12 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {
opt(cfg)
}

if !cfg.ipc {
platformEndpoint, err = SanitizePlatformEndpoint(platformEndpoint)
if err != nil {
return nil, errors.Join(ErrPlatformEndpointMalformed, err)
}
// If IPC is enabled, we need to have a core connection
if cfg.ipc && cfg.coreConn == nil {
return nil, errors.New("core connection is required for IPC mode")
}

// If KAS session key is not provided, generate a new one
if cfg.kasSessionKey == nil {
key, err := ocrypto.NewRSAKeyPair(tdf3KeySize)
if err != nil {
Expand All @@ -104,22 +103,31 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {
dialOptions = append(dialOptions, cfg.extraDialOptions...)
}

// IF IPC is disabled we build a connection to the platform
if !cfg.ipc {
platformEndpoint, err = SanitizePlatformEndpoint(platformEndpoint)
if err != nil {
return nil, errors.Join(ErrPlatformEndpointMalformed, err)
}
}

// If platformConfiguration is not provided, fetch it from the platform
if cfg.platformConfiguration == nil && platformEndpoint != "" { //nolint:nestif // Most of checks are for errors
if cfg.platformConfiguration == nil && !cfg.ipc { //nolint:nestif // Most of checks are for errors
var pcfg PlatformConfiguration
var err error

if cfg.wellknownConn != nil {
pcfg, err = getPlatformConfiguration(cfg.wellknownConn)
if cfg.coreConn != nil {
pcfg, err = getPlatformConfiguration(cfg.coreConn) // Pick a connection until cfg.wellknownConn is removed
if err != nil {
return nil, errors.Join(ErrPlatformConfigFailed, err)
}
} else {
pcfg, err = fetchPlatformConfiguration(platformEndpoint, dialOptions)
}

if err != nil {
return nil, errors.Join(ErrPlatformConfigFailed, err)
if err != nil {
return nil, errors.Join(ErrPlatformConfigFailed, err)
}
}
cfg.platformConfiguration = pcfg

if cfg.tokenEndpoint == "" {
cfg.tokenEndpoint, err = getTokenEndpoint(*cfg)
if err != nil {
Expand All @@ -130,6 +138,9 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {

var uci []grpc.UnaryClientInterceptor

// Add request ID interceptor
uci = append(uci, audit.MetadataAddingClientInterceptor)

accessTokenSource, err := buildIDPTokenSource(cfg)
if err != nil {
return nil, err
Expand All @@ -139,14 +150,13 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {
uci = append(uci, interceptor.AddCredentials)
}

// Add request ID interceptor
uci = append(uci, audit.MetadataAddingClientInterceptor)

dialOptions = append(dialOptions, grpc.WithChainUnaryInterceptor(uci...))

if platformEndpoint != "" {
var err error
defaultConn, err = grpc.Dial(platformEndpoint, dialOptions...)
// If coreConn is provided, use it as the platform connection
if cfg.coreConn != nil {
platformConn = cfg.coreConn
} else {
platformConn, err = grpc.Dial(platformEndpoint, dialOptions...)
if err != nil {
return nil, errors.Join(ErrGrpcDialFailed, err)
}
Expand All @@ -155,18 +165,18 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) {
return &SDK{
config: *cfg,
kasKeyCache: newKasKeyCache(),
conn: defaultConn,
conn: platformConn,
dialOptions: dialOptions,
tokenSource: accessTokenSource,
Attributes: attributes.NewAttributesServiceClient(selectConn(cfg.policyConn, defaultConn)),
Namespaces: namespaces.NewNamespaceServiceClient(selectConn(cfg.policyConn, defaultConn)),
ResourceMapping: resourcemapping.NewResourceMappingServiceClient(selectConn(cfg.policyConn, defaultConn)),
SubjectMapping: subjectmapping.NewSubjectMappingServiceClient(selectConn(cfg.policyConn, defaultConn)),
Unsafe: unsafe.NewUnsafeServiceClient(selectConn(cfg.policyConn, defaultConn)),
KeyAccessServerRegistry: kasregistry.NewKeyAccessServerRegistryServiceClient(selectConn(cfg.policyConn, defaultConn)),
Authorization: authorization.NewAuthorizationServiceClient(selectConn(cfg.authorizationConn, defaultConn)),
EntityResoution: entityresolution.NewEntityResolutionServiceClient(selectConn(cfg.entityresolutionConn, defaultConn)),
wellknownConfiguration: wellknownconfiguration.NewWellKnownServiceClient(selectConn(cfg.wellknownConn, defaultConn)),
Attributes: attributes.NewAttributesServiceClient(platformConn),
Namespaces: namespaces.NewNamespaceServiceClient(platformConn),
ResourceMapping: resourcemapping.NewResourceMappingServiceClient(platformConn),
SubjectMapping: subjectmapping.NewSubjectMappingServiceClient(platformConn),
Unsafe: unsafe.NewUnsafeServiceClient(platformConn),
KeyAccessServerRegistry: kasregistry.NewKeyAccessServerRegistryServiceClient(platformConn),
Authorization: authorization.NewAuthorizationServiceClient(platformConn),
EntityResoution: entityresolution.NewEntityResolutionServiceClient(platformConn),
wellknownConfiguration: wellknownconfiguration.NewWellKnownServiceClient(platformConn),
}, nil
}

Expand Down Expand Up @@ -415,12 +425,3 @@ func getTokenEndpoint(c config) (string, error) {

return tokenEndpoint, nil
}

// selectConn returns the preferred connection if it is not nil, otherwise it returns the fallback connection
// which is the default connection built to the platform.
func selectConn(preferred, fallback *grpc.ClientConn) *grpc.ClientConn {
if preferred != nil {
return preferred
}
return fallback
}
4 changes: 1 addition & 3 deletions service/pkg/server/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ func Start(f ...StartOptions) error {
if slices.Contains(cfg.Mode, "all") || slices.Contains(cfg.Mode, "core") {
// Use IPC for the SDK client
sdkOptions = append(sdkOptions, sdk.WithIPC())
sdkOptions = append(sdkOptions, sdk.WithCustomPolicyConnection(otdf.GRPCInProcess.Conn()))
sdkOptions = append(sdkOptions, sdk.WithCustomAuthorizationConnection(otdf.GRPCInProcess.Conn()))
sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(otdf.GRPCInProcess.Conn()))
sdkOptions = append(sdkOptions, sdk.WithCustomCoreConnection(otdf.GRPCInProcess.Conn()))

client, err = sdk.New("", sdkOptions...)
if err != nil {
Expand Down