diff --git a/sdk/options.go b/sdk/options.go index 7b7cacb7d5..88daabac90 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -22,18 +22,15 @@ type config struct { tokenExchange *oauth.TokenExchangeInfo tokenEndpoint string scopes []string - policyConn *grpc.ClientConn - authorizationConn *grpc.ClientConn - entityresolutionConn *grpc.ClientConn extraDialOptions []grpc.DialOption certExchange *oauth.CertExchangeInfo - wellknownConn *grpc.ClientConn platformConfiguration PlatformConfiguration kasSessionKey *ocrypto.RsaKeyPair dpopKey *ocrypto.RsaKeyPair ipc bool tdfFeatures tdfFeatures customAccessTokenSource auth.AccessTokenSource + coreConn *grpc.ClientConn } // Options specific to TDF protocol features @@ -99,21 +96,24 @@ func withCustomAccessTokenSource(a auth.AccessTokenSource) Option { } } +// Deprecated: Use WithCustomCoreConnection instead func WithCustomPolicyConnection(conn *grpc.ClientConn) Option { return func(c *config) { - c.policyConn = conn + c.coreConn = conn } } +// Deprecated: Use WithCustomCoreConnection instead func WithCustomAuthorizationConnection(conn *grpc.ClientConn) Option { return func(c *config) { - c.authorizationConn = conn + c.coreConn = conn } } +// Deprecated: Use WithCustomCoreConnection instead func WithCustomEntityResolutionConnection(conn *grpc.ClientConn) Option { return func(c *config) { - c.entityresolutionConn = conn + c.coreConn = conn } } @@ -156,7 +156,7 @@ func WithSessionSignerRSA(key *rsa.PrivateKey) Option { func WithCustomWellknownConnection(conn *grpc.ClientConn) Option { return func(c *config) { - c.wellknownConn = conn + c.coreConn = conn } } @@ -183,3 +183,10 @@ func WithNoKIDInKAO() Option { c.tdfFeatures.noKID = true } } + +// WithCoreConnection returns an Option that sets up a connection to the core platform +func WithCustomCoreConnection(conn *grpc.ClientConn) Option { + return func(c *config) { + c.coreConn = conn + } +} diff --git a/sdk/platformconfig.go b/sdk/platformconfig.go index 1b780b0999..572301e40d 100644 --- a/sdk/platformconfig.go +++ b/sdk/platformconfig.go @@ -3,6 +3,14 @@ package sdk import "log/slog" func (s SDK) PlatformIssuer() string { + // This check is needed if we want to fetch platform configuration over ipc + if s.config.platformConfiguration == nil { + cfg, err := getPlatformConfiguration(s.conn) + if err != nil { + slog.Warn("failed to get platform configuration", slog.Any("error", err)) + } + s.config.platformConfiguration = cfg + } value, ok := s.config.platformConfiguration["platform_issuer"].(string) if !ok { slog.Warn("platform_issuer not found in platform configuration") diff --git a/sdk/sdk.go b/sdk/sdk.go index 9eeaf1a645..0e4fd01812 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -66,8 +66,8 @@ type SDK struct { func New(platformEndpoint string, opts ...Option) (*SDK, error) { var ( - defaultConn *grpc.ClientConn // Connection to the platform if no other connection is provided - err error + platformConn *grpc.ClientConn // Connection to the platform + err error ) // Set default options @@ -82,13 +82,12 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { opt(cfg) } - if !cfg.ipc { - platformEndpoint, err = SanitizePlatformEndpoint(platformEndpoint) - if err != nil { - return nil, errors.Join(ErrPlatformEndpointMalformed, err) - } + // If IPC is enabled, we need to have a core connection + if cfg.ipc && cfg.coreConn == nil { + return nil, errors.New("core connection is required for IPC mode") } + // If KAS session key is not provided, generate a new one if cfg.kasSessionKey == nil { key, err := ocrypto.NewRSAKeyPair(tdf3KeySize) if err != nil { @@ -104,22 +103,31 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { dialOptions = append(dialOptions, cfg.extraDialOptions...) } + // IF IPC is disabled we build a connection to the platform + if !cfg.ipc { + platformEndpoint, err = SanitizePlatformEndpoint(platformEndpoint) + if err != nil { + return nil, errors.Join(ErrPlatformEndpointMalformed, err) + } + } + // If platformConfiguration is not provided, fetch it from the platform - if cfg.platformConfiguration == nil && platformEndpoint != "" { //nolint:nestif // Most of checks are for errors + if cfg.platformConfiguration == nil && !cfg.ipc { //nolint:nestif // Most of checks are for errors var pcfg PlatformConfiguration var err error - if cfg.wellknownConn != nil { - pcfg, err = getPlatformConfiguration(cfg.wellknownConn) + if cfg.coreConn != nil { + pcfg, err = getPlatformConfiguration(cfg.coreConn) // Pick a connection until cfg.wellknownConn is removed + if err != nil { + return nil, errors.Join(ErrPlatformConfigFailed, err) + } } else { pcfg, err = fetchPlatformConfiguration(platformEndpoint, dialOptions) - } - - if err != nil { - return nil, errors.Join(ErrPlatformConfigFailed, err) + if err != nil { + return nil, errors.Join(ErrPlatformConfigFailed, err) + } } cfg.platformConfiguration = pcfg - if cfg.tokenEndpoint == "" { cfg.tokenEndpoint, err = getTokenEndpoint(*cfg) if err != nil { @@ -130,6 +138,9 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { var uci []grpc.UnaryClientInterceptor + // Add request ID interceptor + uci = append(uci, audit.MetadataAddingClientInterceptor) + accessTokenSource, err := buildIDPTokenSource(cfg) if err != nil { return nil, err @@ -139,14 +150,13 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { uci = append(uci, interceptor.AddCredentials) } - // Add request ID interceptor - uci = append(uci, audit.MetadataAddingClientInterceptor) - dialOptions = append(dialOptions, grpc.WithChainUnaryInterceptor(uci...)) - if platformEndpoint != "" { - var err error - defaultConn, err = grpc.Dial(platformEndpoint, dialOptions...) + // If coreConn is provided, use it as the platform connection + if cfg.coreConn != nil { + platformConn = cfg.coreConn + } else { + platformConn, err = grpc.Dial(platformEndpoint, dialOptions...) if err != nil { return nil, errors.Join(ErrGrpcDialFailed, err) } @@ -155,18 +165,18 @@ func New(platformEndpoint string, opts ...Option) (*SDK, error) { return &SDK{ config: *cfg, kasKeyCache: newKasKeyCache(), - conn: defaultConn, + conn: platformConn, dialOptions: dialOptions, tokenSource: accessTokenSource, - Attributes: attributes.NewAttributesServiceClient(selectConn(cfg.policyConn, defaultConn)), - Namespaces: namespaces.NewNamespaceServiceClient(selectConn(cfg.policyConn, defaultConn)), - ResourceMapping: resourcemapping.NewResourceMappingServiceClient(selectConn(cfg.policyConn, defaultConn)), - SubjectMapping: subjectmapping.NewSubjectMappingServiceClient(selectConn(cfg.policyConn, defaultConn)), - Unsafe: unsafe.NewUnsafeServiceClient(selectConn(cfg.policyConn, defaultConn)), - KeyAccessServerRegistry: kasregistry.NewKeyAccessServerRegistryServiceClient(selectConn(cfg.policyConn, defaultConn)), - Authorization: authorization.NewAuthorizationServiceClient(selectConn(cfg.authorizationConn, defaultConn)), - EntityResoution: entityresolution.NewEntityResolutionServiceClient(selectConn(cfg.entityresolutionConn, defaultConn)), - wellknownConfiguration: wellknownconfiguration.NewWellKnownServiceClient(selectConn(cfg.wellknownConn, defaultConn)), + Attributes: attributes.NewAttributesServiceClient(platformConn), + Namespaces: namespaces.NewNamespaceServiceClient(platformConn), + ResourceMapping: resourcemapping.NewResourceMappingServiceClient(platformConn), + SubjectMapping: subjectmapping.NewSubjectMappingServiceClient(platformConn), + Unsafe: unsafe.NewUnsafeServiceClient(platformConn), + KeyAccessServerRegistry: kasregistry.NewKeyAccessServerRegistryServiceClient(platformConn), + Authorization: authorization.NewAuthorizationServiceClient(platformConn), + EntityResoution: entityresolution.NewEntityResolutionServiceClient(platformConn), + wellknownConfiguration: wellknownconfiguration.NewWellKnownServiceClient(platformConn), }, nil } @@ -415,12 +425,3 @@ func getTokenEndpoint(c config) (string, error) { return tokenEndpoint, nil } - -// selectConn returns the preferred connection if it is not nil, otherwise it returns the fallback connection -// which is the default connection built to the platform. -func selectConn(preferred, fallback *grpc.ClientConn) *grpc.ClientConn { - if preferred != nil { - return preferred - } - return fallback -} diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index da7e971e72..9e5623a06e 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -156,9 +156,7 @@ func Start(f ...StartOptions) error { if slices.Contains(cfg.Mode, "all") || slices.Contains(cfg.Mode, "core") { // Use IPC for the SDK client sdkOptions = append(sdkOptions, sdk.WithIPC()) - sdkOptions = append(sdkOptions, sdk.WithCustomPolicyConnection(otdf.GRPCInProcess.Conn())) - sdkOptions = append(sdkOptions, sdk.WithCustomAuthorizationConnection(otdf.GRPCInProcess.Conn())) - sdkOptions = append(sdkOptions, sdk.WithCustomEntityResolutionConnection(otdf.GRPCInProcess.Conn())) + sdkOptions = append(sdkOptions, sdk.WithCustomCoreConnection(otdf.GRPCInProcess.Conn())) client, err = sdk.New("", sdkOptions...) if err != nil {