Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions service/pkg/server/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"context"

"github.com/casbin/casbin/v2/persist"
"github.com/opentdf/platform/service/pkg/config"
"github.com/opentdf/platform/service/pkg/serviceregistry"
Expand All @@ -22,7 +24,7 @@ type StartConfig struct {
configLoaders []config.Loader
configLoaderOrder []string

trustKeyManagers []trust.NamedKeyManagerFactory
trustKeyManagerCtxs []trust.NamedKeyManagerCtxFactory
}

// Deprecated: Use WithConfigKey
Expand Down Expand Up @@ -137,9 +139,25 @@ func WithConfigLoaderOrder(loaderOrder []string) StartOptions {
}

// WithTrustKeyManagerFactories option provides factories for creating trust key managers.
// Deprecated: Use WithTrustKeyManagerCtxFactories
func WithTrustKeyManagerFactories(factories ...trust.NamedKeyManagerFactory) StartOptions {
return func(c StartConfig) StartConfig {
c.trustKeyManagers = append(c.trustKeyManagers, factories...)
for _, factory := range factories {
c.trustKeyManagerCtxs = append(c.trustKeyManagerCtxs, trust.NamedKeyManagerCtxFactory{
Name: factory.Name,
Factory: func(_ context.Context, opts *trust.KeyManagerFactoryOptions) (trust.KeyManager, error) {
return factory.Factory(opts)
},
})
}
return c
}
}

// WithTrustKeyManagerFactories option provides factories for creating trust key managers.
func WithTrustKeyManagerCtxFactories(factories ...trust.NamedKeyManagerCtxFactory) StartOptions {
return func(c StartConfig) StartConfig {
c.trustKeyManagerCtxs = append(c.trustKeyManagerCtxs, factories...)
return c
}
}
29 changes: 21 additions & 8 deletions service/pkg/server/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ func registerCoreServices(reg *serviceregistry.Registry, mode []string) ([]strin
}

type startServicesParams struct {
cfg *config.Config
otdf *server.OpenTDFServer
client *sdk.SDK
logger *logging.Logger
reg *serviceregistry.Registry
cacheManager *cache.Manager
keyManagerFactories []trust.NamedKeyManagerFactory
cfg *config.Config
otdf *server.OpenTDFServer
client *sdk.SDK
logger *logging.Logger
reg *serviceregistry.Registry
cacheManager *cache.Manager
keyManagerCtxFactories []trust.NamedKeyManagerCtxFactory
}

// startServices iterates through the registered namespaces and starts the services
Expand All @@ -141,7 +141,19 @@ func startServices(ctx context.Context, params startServicesParams) (func(), err
logger := params.logger
reg := params.reg
cacheManager := params.cacheManager
keyManagerFactories := params.keyManagerFactories
keyManagerCtxFactories := params.keyManagerCtxFactories

// Create a copy of the key manager factories as the context version for legacy services that don't load the new version with context
var keyManagerFactories []trust.NamedKeyManagerFactory
for _, factory := range keyManagerCtxFactories {
keyManagerFactories = append(keyManagerFactories, trust.NamedKeyManagerFactory{
Name: factory.Name,
//nolint:contextcheck // This is called later, so will be in a new context
Factory: func(opts *trust.KeyManagerFactoryOptions) (trust.KeyManager, error) {
return factory.Factory(context.Background(), opts)
},
})
}

for _, ns := range reg.GetNamespaces() {
namespace, err := reg.GetNamespace(ns)
Expand Down Expand Up @@ -225,6 +237,7 @@ func startServices(ctx context.Context, params startServicesParams) (func(), err
Tracer: tracer,
NewCacheClient: createCacheClient,
KeyManagerFactories: keyManagerFactories,
KeyManagerCtxFactories: keyManagerCtxFactories,
})
if err != nil {
return func() {}, err
Expand Down
10 changes: 4 additions & 6 deletions service/pkg/server/services_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/opentdf/platform/service/logger"
"github.com/opentdf/platform/service/pkg/config"
"github.com/opentdf/platform/service/pkg/serviceregistry"
"github.com/opentdf/platform/service/trust"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -276,11 +275,10 @@ func (suite *ServiceTestSuite) TestStartServicesWithVariousCases() {
"foobar": {},
},
},
otdf: otdf,
client: nil,
keyManagerFactories: []trust.NamedKeyManagerFactory{},
logger: newLogger,
reg: registry,
otdf: otdf,
client: nil,
logger: newLogger,
reg: registry,
})

// call cleanup function
Expand Down
14 changes: 7 additions & 7 deletions service/pkg/server/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ func Start(f ...StartOptions) error {

logger.Info("starting services")
gatewayCleanup, err := startServices(ctx, startServicesParams{
cfg: cfg,
otdf: otdf,
client: client,
keyManagerFactories: startConfig.trustKeyManagers,
logger: logger,
reg: svcRegistry,
cacheManager: cacheManager,
cfg: cfg,
otdf: otdf,
client: client,
keyManagerCtxFactories: startConfig.trustKeyManagerCtxs,
logger: logger,
reg: svcRegistry,
cacheManager: cacheManager,
})
if err != nil {
logger.Error("issue starting services", slog.String("error", err.Error()))
Expand Down
12 changes: 6 additions & 6 deletions service/pkg/server/start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,12 @@ func (s *StartTestSuite) Test_Start_When_Extra_Service_Registered() {
"test": {},
},
},
otdf: s,
client: nil,
keyManagerFactories: []trust.NamedKeyManagerFactory{},
logger: logger,
reg: registry,
cacheManager: &cache.Manager{},
otdf: s,
client: nil,
keyManagerCtxFactories: []trust.NamedKeyManagerCtxFactory{},
logger: logger,
reg: registry,
cacheManager: &cache.Manager{},
})
require.NoError(t, err)
defer cleanup()
Expand Down
7 changes: 7 additions & 0 deletions service/pkg/serviceregistry/serviceregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,15 @@ type RegistrationParams struct {
// NewCacheClient is a function that can be used to create a new cache instance for the service
NewCacheClient func(cache.Options) (*cache.Cache, error)

// KeyManagerFactories are the registered key manager factories that can be used to create
// key managers for the service to use.
// Deprecated: Use KeyManagerCtxFactories
KeyManagerFactories []trust.NamedKeyManagerFactory

// KeyManagerCtxFactories are the registered key manager context factories that can be used to create
// key managers for the service to use.
KeyManagerCtxFactories []trust.NamedKeyManagerCtxFactory

////// The following functions are optional and intended to be called by the service //////

// RegisterWellKnownConfig is a function that can be used to register a well-known configuration
Expand Down
31 changes: 21 additions & 10 deletions service/trust/delegating_key_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ type KeyManagerFactoryOptions struct {
// KeyManagerFactory defines the signature for functions that can create KeyManager instances.
type KeyManagerFactory func(opts *KeyManagerFactoryOptions) (KeyManager, error)

// KeyManagerFactoryCtx defines the signature for functions that can create KeyManager instances.
type KeyManagerFactoryCtx func(ctx context.Context, opts *KeyManagerFactoryOptions) (KeyManager, error)

// DelegatingKeyService is a key service that multiplexes between key managers based on the key's mode.
type DelegatingKeyService struct {
// Lookup key manager by mode for a given key identifier
index KeyIndex

// Lazily create key managers based on their mode
managerFactories map[string]KeyManagerFactory
managerFactories map[string]KeyManagerFactoryCtx

// Cache of key managers to avoid creating them multiple times
managers map[string]KeyManager
Expand All @@ -46,14 +49,22 @@ type DelegatingKeyService struct {
func NewDelegatingKeyService(index KeyIndex, l *logger.Logger, c *cache.Cache) *DelegatingKeyService {
return &DelegatingKeyService{
index: index,
managerFactories: make(map[string]KeyManagerFactory),
managerFactories: make(map[string]KeyManagerFactoryCtx),
managers: make(map[string]KeyManager),
l: l,
c: c,
}
}

func (d *DelegatingKeyService) RegisterKeyManager(name string, factory KeyManagerFactory) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.managerFactories[name] = func(_ context.Context, opts *KeyManagerFactoryOptions) (KeyManager, error) {
return factory(opts)
}
}

func (d *DelegatingKeyService) RegisterKeyManagerCtx(name string, factory KeyManagerFactoryCtx) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.managerFactories[name] = factory
Expand Down Expand Up @@ -93,7 +104,7 @@ func (d *DelegatingKeyService) Decrypt(ctx context.Context, keyID KeyIdentifier,
return nil, fmt.Errorf("unable to find key by ID '%s' within index %s: %w", keyID, d.index, err)
}

manager, err := d.getKeyManager(keyDetails.System())
manager, err := d.getKeyManager(ctx, keyDetails.System())
if err != nil {
return nil, fmt.Errorf("unable to get key manager for system '%s': %w", keyDetails.System(), err)
}
Expand All @@ -107,7 +118,7 @@ func (d *DelegatingKeyService) DeriveKey(ctx context.Context, keyID KeyIdentifie
return nil, fmt.Errorf("unable to find key by ID '%s' in index %s: %w", keyID, d.index, err)
}

manager, err := d.getKeyManager(keyDetails.System())
manager, err := d.getKeyManager(ctx, keyDetails.System())
if err != nil {
return nil, fmt.Errorf("unable to get key manager for system '%s': %w", keyDetails.System(), err)
}
Expand All @@ -117,7 +128,7 @@ func (d *DelegatingKeyService) DeriveKey(ctx context.Context, keyID KeyIdentifie

func (d *DelegatingKeyService) GenerateECSessionKey(ctx context.Context, ephemeralPublicKey string) (Encapsulator, error) {
// Assuming a default manager for session key generation
manager, err := d._defKM()
manager, err := d._defKM(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get default key manager: %w", err)
}
Expand All @@ -137,7 +148,7 @@ func (d *DelegatingKeyService) Close() {

// _defKM retrieves or initializes the default KeyManager.
// It uses the `defaultMode` field to determine which manager is the default.
func (d *DelegatingKeyService) _defKM() (KeyManager, error) {
func (d *DelegatingKeyService) _defKM(ctx context.Context) (KeyManager, error) {
d.mutex.Lock()
// Check if defaultKeyManager is already cached
if d.defaultKeyManager == nil {
Expand All @@ -154,7 +165,7 @@ func (d *DelegatingKeyService) _defKM() (KeyManager, error) {
// This call to getKeyManager will handle its own locking and,
// due to the check `if name == currentDefaultMode` in getKeyManager,
// will error out if `defaultModeName` itself is not found, preventing recursion.
manager, err := d.getKeyManager(defaultModeName)
manager, err := d.getKeyManager(ctx, defaultModeName)
if err != nil {
return nil, fmt.Errorf("failed to get default key manager for mode '%s': %w", defaultModeName, err)
}
Expand All @@ -172,7 +183,7 @@ func (d *DelegatingKeyService) _defKM() (KeyManager, error) {
return managerToReturn, nil
}

func (d *DelegatingKeyService) getKeyManager(name string) (KeyManager, error) {
func (d *DelegatingKeyService) getKeyManager(ctx context.Context, name string) (KeyManager, error) {
d.mutex.Lock()

// Check For Manager First
Expand All @@ -189,7 +200,7 @@ func (d *DelegatingKeyService) getKeyManager(name string) (KeyManager, error) {

if factoryExists {
options := &KeyManagerFactoryOptions{Logger: d.l.With("key-manager", name), Cache: d.c}
managerFromFactory, err := factory(options)
managerFromFactory, err := factory(ctx, options)
if err != nil {
return nil, fmt.Errorf("factory for key manager '%s' failed: %w", name, err)
}
Expand All @@ -211,5 +222,5 @@ func (d *DelegatingKeyService) getKeyManager(name string) (KeyManager, error) {
slog.String("requested_name", name),
slog.String("configured_default_mode", currentDefaultMode),
)
return d._defKM() // _defKM handles erroring if the default manager itself cannot be loaded.
return d._defKM(ctx) // _defKM handles erroring if the default manager itself cannot be loaded.
}
6 changes: 6 additions & 0 deletions service/trust/key_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,9 @@ type NamedKeyManagerFactory struct {
Name string
Factory KeyManagerFactory
}

// NamedKeyManagerCtxFactory pairs a KeyManagerFactoryCtx with its intended registration name.
type NamedKeyManagerCtxFactory struct {
Name string
Factory KeyManagerFactoryCtx
}
Loading