From e194a1768a3722ef702e23716037ef09a67fbc88 Mon Sep 17 00:00:00 2001 From: Kaia Lang Date: Mon, 12 May 2025 10:20:04 -0700 Subject: [PATCH 1/3] feat(core): define module registry and lifecycle hooks --- .../custom_modules/database_module.go | 28 ++ .../modules_v1/module_lifecycle_test.go | 36 ++ router/core/hooks.go | 264 +++++++++++ router/core/hooks_context.go | 75 ++++ router/core/hooks_test.go | 70 +++ router/core/modules_v1.go | 232 ++++++++++ router/core/modules_v1_test.go | 409 ++++++++++++++++++ router/core/router.go | 44 ++ router/core/router_config.go | 1 + router/internal/utils/ordered_set.go | 61 +++ router/internal/utils/ptrs.go | 6 + 11 files changed, 1226 insertions(+) create mode 100644 router-tests/modules_v1/custom_modules/database_module.go create mode 100644 router-tests/modules_v1/module_lifecycle_test.go create mode 100644 router/core/hooks.go create mode 100644 router/core/hooks_context.go create mode 100644 router/core/hooks_test.go create mode 100644 router/core/modules_v1.go create mode 100644 router/core/modules_v1_test.go create mode 100644 router/internal/utils/ordered_set.go create mode 100644 router/internal/utils/ptrs.go diff --git a/router-tests/modules_v1/custom_modules/database_module.go b/router-tests/modules_v1/custom_modules/database_module.go new file mode 100644 index 0000000000..b475be5b98 --- /dev/null +++ b/router-tests/modules_v1/custom_modules/database_module.go @@ -0,0 +1,28 @@ +package custom_modules + +import ( + "github.com/wundergraph/cosmo/router/core" +) + +type DatabaseModule struct {} + +func (m *DatabaseModule) Module() core.ModuleV1Info { + priority := 2 + return core.ModuleV1Info{ + ID: "database_module", + Priority: &priority, + New: func() core.ModuleV1 { + return &DatabaseModule{} + }, + } +} + +func (m *DatabaseModule) Provision(ctx *core.ModuleV1Context) error { + ctx.Logger.Info("Database module provisioned") + return nil +} + +func (m *DatabaseModule) Cleanup(ctx *core.ModuleV1Context) error { + ctx.Logger.Info("Database module cleaned up") + return nil +} diff --git a/router-tests/modules_v1/module_lifecycle_test.go b/router-tests/modules_v1/module_lifecycle_test.go new file mode 100644 index 0000000000..dbe8b76abb --- /dev/null +++ b/router-tests/modules_v1/module_lifecycle_test.go @@ -0,0 +1,36 @@ +package modules_v1 + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/modules_v1/custom_modules" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" +) + +func TestModuleV1ProvisionAndCleanupLifecycle(t *testing.T) { + t.Parallel() + + t.Run("no regression with the module system introduced", func(t *testing.T) { + t.Parallel() + + dbModule := &custom_modules.DatabaseModule{} + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithCustomModulesV1(dbModule), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query MyQuery { employees { id } }`, + OperationName: json.RawMessage(`"MyQuery"`), + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + assert.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body) + }) + }) +} diff --git a/router/core/hooks.go b/router/core/hooks.go new file mode 100644 index 0000000000..2f781ed59c --- /dev/null +++ b/router/core/hooks.go @@ -0,0 +1,264 @@ +package core + +import ( + "github.com/wundergraph/cosmo/router/internal/utils" + "go.uber.org/zap" +) + +// Application Lifecycle Hooks +type ApplicationLifecycleHook interface { + ApplicationStartHook + ApplicationStopHook +} + +type ApplicationStartHook interface { + OnApplicationStart(ctx ApplicationStartHookContext) error +} + +type ApplicationStopHook interface { + OnApplicationStop(ctx ApplicationStopHookContext) error +} + +// GraphQL Server Lifecycle Hooks +type GraphQLServerLifecycleHook interface { + GraphQLServerStartHook + GraphQLServerStopHook +} + +type GraphQLServerStartHook interface { + OnGraphQLServerStart(ctx GraphQLServerStartHookContext) error +} + +type GraphQLServerStopHook interface { + OnGraphQLServerStop(ctx GraphQLServerStopHookContext) error +} + +// Router Lifecycle Hooks +type RouterRequestHook interface { + OnRouterRequest(ctx RouterRequestHookContext) error +} + +type RouterResponseHook interface { + OnRouterResponse(ctx RouterResponseHookContext) error +} + +type RouterLifecycleHook interface { + RouterRequestHook + RouterResponseHook +} + +// Subgraph Lifecycle Hooks +type SubgraphRequestHook interface { + OnSubgraphRequest(ctx SubgraphRequestHookContext) error +} + +type SubgraphResponseHook interface { + OnSubgraphResponse(ctx SubgraphResponseHookContext) error +} + +type SubgraphLifecycleHook interface { + SubgraphRequestHook + SubgraphResponseHook +} + +// Operation Lifecycle Hooks +type OperationLifecycleHook interface { + OperationParseLifecycleHook + OperationNormalizeLifecycleHook + OperationValidateLifecycleHook + OperationPlanLifecycleHook + OperationExecuteLifecycleHook +} + +type OperationParseLifecycleHook interface { + OperationPreParseHook + OperationPostParseHook +} + +type OperationPreParseHook interface { + OnOperationPreParse(ctx OperationPreParseHookContext) error +} + +type OperationPostParseHook interface { + OnOperationPostParse(ctx OperationPostParseHookContext) error +} + +type OperationNormalizeLifecycleHook interface { + OperationPreNormalizeHook + OperationPostNormalizeHook +} + +type OperationPreNormalizeHook interface { + OnOperationPreNormalize(ctx OperationPreNormalizeHookContext) error +} + +type OperationPostNormalizeHook interface { + OnOperationPostNormalize(ctx OperationPostNormalizeHookContext) error +} + +type OperationValidateLifecycleHook interface { + OperationPreValidateHook + OperationPostValidateHook +} + +type OperationPreValidateHook interface { + OnOperationPreValidate(ctx OperationPreValidateHookContext) error +} + +type OperationPostValidateHook interface { + OnOperationPostValidate(ctx OperationPostValidateHookContext) error +} + +type OperationPlanLifecycleHook interface { + OperationPrePlanHook + OperationPostPlanHook +} + +type OperationPrePlanHook interface { + OnOperationPrePlan(ctx OperationPrePlanHookContext) error +} + +type OperationPostPlanHook interface { + OnOperationPostPlan(ctx OperationPostPlanHookContext) error +} + +type OperationExecuteLifecycleHook interface { + OperationPreExecuteHook + OperationPostExecuteHook +} + +type OperationPreExecuteHook interface { + OnOperationPreExecute(ctx OperationPreExecuteHookContext) error +} + +type OperationPostExecuteHook interface { + OnOperationPostExecute(ctx OperationPostExecuteHookContext) error +} + +// moduleHook is a wrapper around a hook that includes the module ID. +// this is used for tracability in case of hook execution errors. +type moduleHook[H any] struct { + ID string + Hook H +} + +// hookRegistry holds the list of hooks for each type. +type hookRegistry struct { + applicationStartHooks *utils.OrderedSet[moduleHook[ApplicationStartHook]] + applicationStopHooks *utils.OrderedSet[moduleHook[ApplicationStopHook]] + + graphQLServerStartHooks *utils.OrderedSet[moduleHook[GraphQLServerStartHook]] + graphQLServerStopHooks *utils.OrderedSet[moduleHook[GraphQLServerStopHook]] + + routerRequestHooks *utils.OrderedSet[moduleHook[RouterRequestHook]] + routerResponseHooks *utils.OrderedSet[moduleHook[RouterResponseHook]] + + subgraphRequestHooks *utils.OrderedSet[moduleHook[SubgraphRequestHook]] + subgraphResponseHooks *utils.OrderedSet[moduleHook[SubgraphResponseHook]] + + operationPreParseHooks *utils.OrderedSet[moduleHook[OperationPreParseHook]] + operationPostParseHooks *utils.OrderedSet[moduleHook[OperationPostParseHook]] + + operationPreNormalizeHooks *utils.OrderedSet[moduleHook[OperationPreNormalizeHook]] + operationPostNormalizeHooks *utils.OrderedSet[moduleHook[OperationPostNormalizeHook]] + + operationPreValidateHooks *utils.OrderedSet[moduleHook[OperationPreValidateHook]] + operationPostValidateHooks *utils.OrderedSet[moduleHook[OperationPostValidateHook]] + + operationPrePlanHooks *utils.OrderedSet[moduleHook[OperationPrePlanHook]] + operationPostPlanHooks *utils.OrderedSet[moduleHook[OperationPostPlanHook]] + + operationPreExecuteHooks *utils.OrderedSet[moduleHook[OperationPreExecuteHook]] + operationPostExecuteHooks *utils.OrderedSet[moduleHook[OperationPostExecuteHook]] +} + +// newHookRegistry initializes with empty sets. +func newHookRegistry() *hookRegistry { + return &hookRegistry{ + applicationStartHooks: utils.NewOrderedSet[moduleHook[ApplicationStartHook]](), + applicationStopHooks: utils.NewOrderedSet[moduleHook[ApplicationStopHook]](), + + graphQLServerStartHooks: utils.NewOrderedSet[moduleHook[GraphQLServerStartHook]](), + graphQLServerStopHooks: utils.NewOrderedSet[moduleHook[GraphQLServerStopHook]](), + + routerRequestHooks: utils.NewOrderedSet[moduleHook[RouterRequestHook]](), + routerResponseHooks: utils.NewOrderedSet[moduleHook[RouterResponseHook]](), + + subgraphRequestHooks: utils.NewOrderedSet[moduleHook[SubgraphRequestHook]](), + subgraphResponseHooks: utils.NewOrderedSet[moduleHook[SubgraphResponseHook]](), + + operationPreParseHooks: utils.NewOrderedSet[moduleHook[OperationPreParseHook]](), + operationPostParseHooks: utils.NewOrderedSet[moduleHook[OperationPostParseHook]](), + + operationPreNormalizeHooks: utils.NewOrderedSet[moduleHook[OperationPreNormalizeHook]](), + operationPostNormalizeHooks: utils.NewOrderedSet[moduleHook[OperationPostNormalizeHook]](), + + operationPreValidateHooks: utils.NewOrderedSet[moduleHook[OperationPreValidateHook]](), + operationPostValidateHooks: utils.NewOrderedSet[moduleHook[OperationPostValidateHook]](), + + operationPrePlanHooks: utils.NewOrderedSet[moduleHook[OperationPrePlanHook]](), + operationPostPlanHooks: utils.NewOrderedSet[moduleHook[OperationPostPlanHook]](), + + operationPreExecuteHooks: utils.NewOrderedSet[moduleHook[OperationPreExecuteHook]](), + operationPostExecuteHooks: utils.NewOrderedSet[moduleHook[OperationPostExecuteHook]](), + } +} + +// registerHook is a helper to add any hook type if implemented. +func registerHook[H comparable](inst any, set *utils.OrderedSet[moduleHook[H]], moduleID string) { + if h, ok := inst.(H); ok { + set.Add(moduleHook[H]{ + ID: moduleID, + Hook: h, + }) + } +} + +// AddApplicationLifecycle registers start/stop hooks. +func (hr *hookRegistry) AddApplicationLifecycle(inst any, moduleID string) { + registerHook(inst, hr.applicationStartHooks, moduleID) + registerHook(inst, hr.applicationStopHooks, moduleID) +} + +// AddGraphQLServerLifecycle registers GraphQL server start/stop hooks. +func (hr *hookRegistry) AddGraphQLServerLifecycle(inst any, moduleID string) { + registerHook(inst, hr.graphQLServerStartHooks, moduleID) + registerHook(inst, hr.graphQLServerStopHooks, moduleID) +} + +// AddRouterLifecycle registers router request/response hooks. +func (hr *hookRegistry) AddRouterLifecycle(inst any, moduleID string) { + registerHook(inst, hr.routerRequestHooks, moduleID) + registerHook(inst, hr.routerResponseHooks, moduleID) +} + +// AddSubgraphLifecycle registers subgraph request/response hooks. +func (hr *hookRegistry) AddSubgraphLifecycle(inst any, moduleID string) { + registerHook(inst, hr.subgraphRequestHooks, moduleID) + registerHook(inst, hr.subgraphResponseHooks, moduleID) +} + +// AddOperationLifecycle registers all operation lifecycle hooks. +func (hr *hookRegistry) AddOperationLifecycle(inst any, moduleID string) { + registerHook(inst, hr.operationPreParseHooks, moduleID) + registerHook(inst, hr.operationPostParseHooks, moduleID) + registerHook(inst, hr.operationPreNormalizeHooks, moduleID) + registerHook(inst, hr.operationPostNormalizeHooks, moduleID) + registerHook(inst, hr.operationPreValidateHooks, moduleID) + registerHook(inst, hr.operationPostValidateHooks, moduleID) + registerHook(inst, hr.operationPrePlanHooks, moduleID) + registerHook(inst, hr.operationPostPlanHooks, moduleID) + registerHook(inst, hr.operationPreExecuteHooks, moduleID) + registerHook(inst, hr.operationPostExecuteHooks, moduleID) +} + +// executeHooks executes the hooks in the order they were registered. +func executeHooks[H any](hooks []moduleHook[H], invoke func(H) error, hookName string, logger *zap.Logger) error { + logger.Debug("executing hooks", zap.String("hookName", hookName), zap.Int("hooks", len(hooks))) + for _, mk := range hooks { + if err := invoke(mk.Hook); err != nil { + return newModuleV1HookError(mk.ID, hookName, err) + } + } + return nil +} diff --git a/router/core/hooks_context.go b/router/core/hooks_context.go new file mode 100644 index 0000000000..012a9a41a9 --- /dev/null +++ b/router/core/hooks_context.go @@ -0,0 +1,75 @@ +package core + +// context interface for every hook +type ApplicationStartHookContext interface { + RequestContext +} + +type ApplicationStopHookContext interface { + RequestContext +} + +type GraphQLServerStartHookContext interface { + RequestContext +} + +type GraphQLServerStopHookContext interface { + RequestContext +} + +type RouterRequestHookContext interface { + RequestContext +} + +type RouterResponseHookContext interface { + RequestContext +} + +type SubgraphRequestHookContext interface { + RequestContext +} + +type SubgraphResponseHookContext interface { + RequestContext +} + +type OperationPreParseHookContext interface { + RequestContext +} + +type OperationPostParseHookContext interface { + RequestContext +} + +type OperationPreNormalizeHookContext interface { + RequestContext +} + +type OperationPostNormalizeHookContext interface { + RequestContext +} + +type OperationPreValidateHookContext interface { + RequestContext +} + +type OperationPostValidateHookContext interface { + RequestContext +} + +type OperationPrePlanHookContext interface { + RequestContext +} + +type OperationPostPlanHookContext interface { + RequestContext +} + +type OperationPreExecuteHookContext interface { + RequestContext +} + +type OperationPostExecuteHookContext interface { + RequestContext +} + diff --git a/router/core/hooks_test.go b/router/core/hooks_test.go new file mode 100644 index 0000000000..bc0e9e6b9e --- /dev/null +++ b/router/core/hooks_test.go @@ -0,0 +1,70 @@ +package core + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +type mockHook interface { + Do(ctx context.Context) error +} + +type mockHookImpl struct { + id string + shouldErr bool +} + +func (m *mockHookImpl) Do(ctx context.Context) error { + if m.shouldErr { + return fmt.Errorf("hook %s failed", m.id) + } + return nil +} + +func TestExecuteHooks(t *testing.T) { + ctx := context.Background() + + t.Run("all hooks succeed", func(t *testing.T) { + hooks := []moduleHook[mockHook]{ + {ID: "module1", Hook: &mockHookImpl{id: "hook1"}}, + {ID: "module2", Hook: &mockHookImpl{id: "hook2"}}, + } + + err := executeHooks(hooks, func(h mockHook) error { + return h.Do(ctx) + }, "MockHook", zaptest.NewLogger(t)) + + require.NoError(t, err) + }) + + t.Run("one hook fails", func(t *testing.T) { + hooks := []moduleHook[mockHook]{ + {ID: "module1", Hook: &mockHookImpl{id: "hook1"}}, + {ID: "moduleFail", Hook: &mockHookImpl{id: "hook2", shouldErr: true}}, + {ID: "module3", Hook: &mockHookImpl{id: "hook3"}}, + } + + err := executeHooks(hooks, func(h mockHook) error { + return h.Do(ctx) + }, "MockHook", zaptest.NewLogger(t)) + + require.Error(t, err) + + assert.Equal(t, err.Error(), "module moduleFail hook MockHook error: hook hook2 failed") + }) + + t.Run("empty hook list", func(t *testing.T) { + var hooks []moduleHook[mockHook] + + err := executeHooks(hooks, func(h mockHook) error { + return h.Do(ctx) + }, "MockHook", zaptest.NewLogger(t)) + + require.NoError(t, err) + }) +} diff --git a/router/core/modules_v1.go b/router/core/modules_v1.go new file mode 100644 index 0000000000..ecbbc82e8a --- /dev/null +++ b/router/core/modules_v1.go @@ -0,0 +1,232 @@ +package core + +import ( + "context" + "errors" + "fmt" + "math" + "sort" + "sync" + "time" + + "go.uber.org/zap" +) + +type moduleRegistry struct { + mu sync.RWMutex + modules map[string]ModuleV1Info +} + +// NewModuleRegistry returns an empty, thread-safe module registry. +// Call this in tests (and anywhere you need isolation) instead of using the global. +func newModuleRegistry() *moduleRegistry { + return &moduleRegistry{ + modules: make(map[string]ModuleV1Info), + } +} + +// defaultModuleRegistry is the package-level registry used by RegisterModuleV1. +// For unit tests you should use newModuleRegistry() to get a fresh instance and avoid shared state. +var defaultModuleRegistry = newModuleRegistry() + +type ModuleV1Info struct { + // ID is the unique identifier for a module, it must be unique across all modules. + ID string + // Priority decides the order of execution of the module. + // The smaller the number, the higher the priority, the earlier the module is executed. + // For example, a priority of 0 is the highest priority. + // Modules with the same priority are executed in the order they are registered. + // If Priority is nil, the module is considered to have the lowest priority. + Priority *int + // New creates a new instance of the module. + New func() ModuleV1 +} + +// ModuleV1Context provides context and utilities for module provisioning +// Maintains feature parity with the old module system +type ModuleV1Context struct { + context.Context + Module ModuleV1 + Logger *zap.Logger +} + +type ModuleV1 interface { + Module() ModuleV1Info + // Provisioner is called before the server starts + // It allows you to initialize your module e.g. create a database connection + Provision(ctx *ModuleV1Context) error + // Cleanup is called after the server stops + // It allows you to clean up your module e.g. close a database connection + Cleanup(ctx *ModuleV1Context) error +} + +// RegisterModuleV1 registers a new ModuleV1 instance. +// The registration order matters. Modules with the same priority +// are executed in the order they are registered. +// It panics if the module is already registered. +func RegisterModuleV1(instance ModuleV1) { + defaultModuleRegistry.registerModuleV1(instance) +} + +func (r *moduleRegistry) registerModuleV1(instance ModuleV1) { + m := instance.Module() + + if m.ID == "" { + panic("ModuleV1.ID is required") + } + if val := m.New(); val == nil { + panic(fmt.Sprintf("ModuleV1Info.New must return a non-nil module instance: %s", m.ID)) + } + + r.mu.Lock() + defer r.mu.Unlock() + + if _, ok := r.modules[m.ID]; ok { + panic(fmt.Sprintf("ModuleV1 already registered: %s", m.ID)) + } + r.modules[m.ID] = m +} + +// sortModulesV1 sorts the modules by priority, 0 is the highest priority, is the first to be executed. +// If two modules have the same priority, they are sorted by registration order. +// If a module has no priority, it is considered to have the lowest priority. +func sortModulesV1(modules []ModuleV1Info) []ModuleV1Info { + sort.Slice(modules, func(i, j int) bool { + var priorityI, priorityJ int = math.MaxInt, math.MaxInt + if modules[i].Priority != nil { + priorityI = *modules[i].Priority + } + if modules[j].Priority != nil { + priorityJ = *modules[j].Priority + } + + return priorityI < priorityJ + }) + return modules +} + +// getModulesV1 returns all registered modules sorted by priority +func (r *moduleRegistry) getModulesV1() []ModuleV1Info { + r.mu.RLock() + defer r.mu.RUnlock() + + modules := make([]ModuleV1Info, 0, len(r.modules)) + for _, m := range r.modules { + modules = append(modules, m) + } + return sortModulesV1(modules) +} + +// coreModuleHooks manages module initialization and hook registration. +type coreModuleHooks struct { + moduleInstances []ModuleV1 + hookRegistry *hookRegistry + logger *zap.Logger +} + +func newCoreModuleHooks(logger *zap.Logger) *coreModuleHooks { + return &coreModuleHooks{ + hookRegistry: newHookRegistry(), + logger: logger, + } +} + +// initCoreModuleHooks instantiates each module, provisions it, +// registers any implemented hooks, and saves the hook registry. +func (c *coreModuleHooks) initCoreModuleHooks(ctx context.Context, modules []ModuleV1Info) error { + hookRegistry := newHookRegistry() + var instances []ModuleV1 + + for _, info := range modules { + now := time.Now() + moduleInstance := info.New() + + moduleCtx := &ModuleV1Context{ + Context: ctx, + Module: moduleInstance, + Logger: c.logger.Named(info.ID), + } + + if err := moduleInstance.Provision(moduleCtx); err != nil { + return newModuleV1Error(info.ID, PhaseProvision, err) + } + + hookRegistry.AddApplicationLifecycle(moduleInstance, info.ID) + hookRegistry.AddGraphQLServerLifecycle(moduleInstance, info.ID) + hookRegistry.AddRouterLifecycle(moduleInstance, info.ID) + hookRegistry.AddSubgraphLifecycle(moduleInstance, info.ID) + hookRegistry.AddOperationLifecycle(moduleInstance, info.ID) + + c.logger.Info("Core Module System: Module registered", + zap.String("id", string(info.ID)), + zap.String("duration", time.Since(now).String()), + ) + + instances = append(instances, moduleInstance) + } + + c.hookRegistry = hookRegistry + c.moduleInstances = instances + + return nil +} + +func (c *coreModuleHooks) cleanupCoreModuleHooks(ctx context.Context) error { + var errs []error + for _, moduleInstance := range c.moduleInstances { + moduleCtx := &ModuleV1Context{ + Context: ctx, + Module: moduleInstance, + Logger: c.logger.Named(moduleInstance.Module().ID), + } + if err := moduleInstance.Cleanup(moduleCtx); err != nil { + errs = append(errs, newModuleV1Error(moduleInstance.Module().ID, PhaseCleanup, err)) + } + } + + return errors.Join(errs...) +} + +// ModuleV1Error provides structured error information for module operations +type ModuleV1Error struct { + ModuleID string + Phase phase + HookName *string + Err error +} + +type phase string + +const ( + PhaseProvision phase = "provision" + PhaseCleanup phase = "cleanup" + PhaseHook phase = "hook" +) + +func (e *ModuleV1Error) Error() string { + if e.Phase == PhaseHook && e.HookName != nil { + return fmt.Sprintf("module %s %s %s error: %v", e.ModuleID, e.Phase, *e.HookName, e.Err) + } + return fmt.Sprintf("module %s %s error: %v", e.ModuleID, e.Phase, e.Err) +} + +func (e *ModuleV1Error) Unwrap() error { + return e.Err +} + +func newModuleV1Error(moduleID string, phase phase, err error) error { + return &ModuleV1Error{ + ModuleID: moduleID, + Phase: phase, + Err: err, + } +} + +func newModuleV1HookError(moduleID, hookName string, err error) error { + return &ModuleV1Error{ + ModuleID: moduleID, + Phase: PhaseHook, + HookName: &hookName, + Err: err, + } +} diff --git a/router/core/modules_v1_test.go b/router/core/modules_v1_test.go new file mode 100644 index 0000000000..3b6cc61f2d --- /dev/null +++ b/router/core/modules_v1_test.go @@ -0,0 +1,409 @@ +package core + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/utils" + + "go.uber.org/zap/zaptest" +) + +type testModule1 struct{} + +func (m *testModule1) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "testModule1", + Priority: utils.Ptr(0), + New: func() ModuleV1 { + return &testModule1{} + }, + } +} +func (m *testModule1) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule1) Cleanup(ctx *ModuleV1Context) error { return errors.New("test error 1") } + +func (m *testModule1) OnApplicationStart(ctx ApplicationStartHookContext) error { return nil } + +type testModule2 struct{} + +func (m *testModule2) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "testModule2", + Priority: utils.Ptr(1), + New: func() ModuleV1 { + return &testModule2{} + }, + } +} +func (m *testModule2) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule2) Cleanup(ctx *ModuleV1Context) error { return nil } + +func (m *testModule2) OnApplicationStart(ctx ApplicationStartHookContext) error { return nil } +func (m *testModule2) OnApplicationStop(ctx ApplicationStopHookContext) error { return nil } + +type testModule3 struct{} + +func (m *testModule3) Module() ModuleV1Info { + return ModuleV1Info{ + Priority: utils.Ptr(1), + New: func() ModuleV1 { + return &testModule3{} + }, + } +} +func (m *testModule3) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule3) Cleanup(ctx *ModuleV1Context) error { return nil } + +func (m *testModule3) OnApplicationStart(ctx ApplicationStartHookContext) error { return nil } +func (m *testModule3) OnApplicationStop(ctx ApplicationStopHookContext) error { return nil } + +type testModule4 struct{} + +func (m *testModule4) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "testModule4", + Priority: utils.Ptr(1), + } +} +func (m *testModule4) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule4) Cleanup(ctx *ModuleV1Context) error { return errors.New("test error 4") } + +// interface guards +var _ ApplicationStartHook = (*testModule1)(nil) + +// registers the applicationStartHook only once +var _ ApplicationStartHook = (*testModule2)(nil) +var _ ApplicationLifecycleHook = (*testModule2)(nil) + +var _ ApplicationStartHook = (*testModule3)(nil) +var _ ApplicationStopHook = (*testModule3)(nil) + +func TestRegisterModuleV1(t *testing.T) { + t.Parallel() + + m1 := &testModule1{} + m2 := &testModule2{} + m3 := &testModule3{} + m4 := &testModule4{} + m5 := &testModule1{} + t.Run("success", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + testModuleRegistry.registerModuleV1(m1) + testModuleRegistry.registerModuleV1(m2) + + require.Equal(t, "testModule1", testModuleRegistry.getModulesV1()[0].ID) + require.Equal(t, "testModule2", testModuleRegistry.getModulesV1()[1].ID) + }) + + t.Run("panic if module id is empty", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + require.Panics(t, func() { + testModuleRegistry.registerModuleV1(m3) + }) + }) + + t.Run("panic if module new returns nil", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + require.Panics(t, func() { + testModuleRegistry.registerModuleV1(m4) + }) + }) + + t.Run("panic if module id is not unique", func(t *testing.T) { + testModuleRegistry := newModuleRegistry() + + require.Panics(t, func() { + testModuleRegistry.registerModuleV1(m1) + testModuleRegistry.registerModuleV1(m5) + }) + }) + + t.Run("module implements both ApplicationStartHook and ApplicationLifecycleHook should only register ApplicationStartHook once", func(t *testing.T) { + hookRegistry := newHookRegistry() + module := &testModule2{} + + hookRegistry.AddApplicationLifecycle(module, "testModule2") + + registerHook(module, hookRegistry.applicationStartHooks, "testModule2") + + require.Equal(t, 1, len(hookRegistry.applicationStartHooks.Values())) + + hooks := hookRegistry.applicationStartHooks.Values() + require.Equal(t, "testModule2", hooks[0].ID) + }) + + t.Run("different modules with same hook type should both be registered", func(t *testing.T) { + hookRegistry := newHookRegistry() + module1 := &testModule1{} + module2 := &testModule2{} + + registerHook(module1, hookRegistry.applicationStartHooks, "testModule1") + registerHook(module2, hookRegistry.applicationStartHooks, "testModule2") + + require.Equal(t, 2, len(hookRegistry.applicationStartHooks.Values())) + }) +} + +func TestSortModulesV1(t *testing.T) { + t.Parallel() + + module0 := ModuleV1Info{ + ID: "module0", + Priority: utils.Ptr(0), + } + + module1 := ModuleV1Info{ + ID: "module1", + Priority: utils.Ptr(1), + } + + module2 := ModuleV1Info{ + ID: "module2", + Priority: utils.Ptr(2), + } + + module3 := ModuleV1Info{ + ID: "module3", + Priority: utils.Ptr(0), + } + + moduleNilPriority := ModuleV1Info{ + ID: "moduleNil", + } + + t.Run("success", func(t *testing.T) { + modules := []ModuleV1Info{ + moduleNilPriority, + module2, + module0, + module1, + } + result := sortModulesV1(modules) + + expected := []ModuleV1Info{ + module0, + module1, + module2, + moduleNilPriority, + } + + require.EqualValues(t, expected, result) + }) + + t.Run("same priority", func(t *testing.T) { + modules := []ModuleV1Info{ + module3, + module0, + } + result := sortModulesV1(modules) + + expected := []ModuleV1Info{ + module3, + module0, + } + + require.EqualValues(t, expected, result) + }) + + t.Run("no modules not panic", func(t *testing.T) { + modules := []ModuleV1Info{} + require.Equal(t, []ModuleV1Info{}, sortModulesV1(modules)) + }) +} + +func TestInitModulesV1(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + modules := []ModuleV1Info{ + { + ID: "testModule1", + New: func() ModuleV1 { + return &testModule1{} + }, + }, + { + ID: "testModule3", + New: func() ModuleV1 { + return &testModule3{} + }, + }, + } + cm := newCoreModuleHooks(zaptest.NewLogger(t)) + err := cm.initCoreModuleHooks(context.Background(), modules) + require.NoError(t, err) + + require.Equal(t, 2, len(cm.hookRegistry.applicationStartHooks.Values())) + require.Equal(t, 1, len(cm.hookRegistry.applicationStopHooks.Values())) + }) +} + +func TestCleanupModulesV1(t *testing.T) { + t.Parallel() + + t.Run("all modules get a chance to cleanup", func(t *testing.T) { + modules := []ModuleV1Info{ + { + ID: "testModule1", + New: func() ModuleV1 { + return &testModule1{} + }, + }, + { + ID: "testModule2", + New: func() ModuleV1 { + return &testModule2{} + }, + }, + { + ID: "testModule4", + New: func() ModuleV1 { + return &testModule4{} + }, + }, + } + cm := newCoreModuleHooks(zaptest.NewLogger(t)) + err := cm.initCoreModuleHooks(context.Background(), modules) + require.NoError(t, err) + + err = cm.cleanupCoreModuleHooks(context.Background()) + require.Error(t, err) + require.Equal(t, "module testModule1 cleanup error: test error 1\nmodule testModule4 cleanup error: test error 4", err.Error()) + }) +} + +type failingProvisionModule struct{} + +func (m *failingProvisionModule) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "failing-provision-module", + New: func() ModuleV1 { + return &failingProvisionModule{} + }, + } +} + +func (m *failingProvisionModule) Provision(ctx *ModuleV1Context) error { + return errors.New("provision failed") +} + +func (m *failingProvisionModule) Cleanup(ctx *ModuleV1Context) error { + return nil +} + +func TestProvisionErrors(t *testing.T) { + t.Parallel() + + t.Run("provision failure stops initialization", func(t *testing.T) { + modules := []ModuleV1Info{ + { + ID: "failing-provision-module", + New: func() ModuleV1 { + return &failingProvisionModule{} + }, + }, + } + + cm := newCoreModuleHooks(zaptest.NewLogger(t)) + err := cm.initCoreModuleHooks(context.Background(), modules) + + require.Error(t, err) + var moduleErr *ModuleV1Error + require.True(t, errors.As(err, &moduleErr)) + require.Equal(t, "failing-provision-module", moduleErr.ModuleID) + require.Equal(t, PhaseProvision, moduleErr.Phase) + }) +} + +type failingHookModule struct{} + +func (m *failingHookModule) OnApplicationStart(ctx ApplicationStartHookContext) error { + return errors.New("hook execution failed") +} + +func TestHookExecution(t *testing.T) { + t.Parallel() + + t.Run("hook execution with error", func(t *testing.T) { + hooks := []moduleHook[ApplicationStartHook]{ + { + ID: "failing-module", + Hook: &failingHookModule{}, + }, + } + + err := executeHooks(hooks, func(h ApplicationStartHook) error { + return h.OnApplicationStart(nil) + }, "OnApplicationStart", zaptest.NewLogger(t)) + + require.Error(t, err) + var moduleErr *ModuleV1Error + require.True(t, errors.As(err, &moduleErr)) + require.Equal(t, "failing-module", moduleErr.ModuleID) + require.Equal(t, PhaseHook, moduleErr.Phase) + require.Equal(t, "OnApplicationStart", *moduleErr.HookName) + }) + + t.Run("hook execution success", func(t *testing.T) { + hooks := []moduleHook[ApplicationStartHook]{ + { + ID: "test-module", + Hook: &testModule1{}, + }, + } + + err := executeHooks(hooks, func(h ApplicationStartHook) error { + return h.OnApplicationStart(nil) + }, "OnApplicationStart", zaptest.NewLogger(t)) + + require.NoError(t, err) + }) +} + +func TestModuleV1Error(t *testing.T) { + t.Parallel() + + t.Run("provision error", func(t *testing.T) { + err := newModuleV1Error("test-module", PhaseProvision, errors.New("foo")) + moduleErr := err.(*ModuleV1Error) + + require.Equal(t, "test-module", moduleErr.ModuleID) + require.Equal(t, PhaseProvision, moduleErr.Phase) + require.Nil(t, moduleErr.HookName) + require.Equal(t, "module test-module provision error: foo", err.Error()) + }) + + t.Run("cleanup error", func(t *testing.T) { + err := newModuleV1Error("test-module", PhaseCleanup, errors.New("foo")) + moduleErr := err.(*ModuleV1Error) + + require.Equal(t, "test-module", moduleErr.ModuleID) + require.Equal(t, PhaseCleanup, moduleErr.Phase) + require.Nil(t, moduleErr.HookName) + require.Equal(t, "module test-module cleanup error: foo", err.Error()) + }) + + t.Run("hook error", func(t *testing.T) { + err := newModuleV1HookError("test-module", "OnApplicationStart", errors.New("foo")) + moduleErr := err.(*ModuleV1Error) + + require.Equal(t, "test-module", moduleErr.ModuleID) + require.Equal(t, PhaseHook, moduleErr.Phase) + require.NotNil(t, moduleErr.HookName) + require.Equal(t, "OnApplicationStart", *moduleErr.HookName) + require.Equal(t, "module test-module hook OnApplicationStart error: foo", err.Error()) + }) + + t.Run("error unwrapping", func(t *testing.T) { + originalErr := errors.New("original error") + err := newModuleV1Error("test-module", PhaseProvision, originalErr) + + require.Equal(t, originalErr, errors.Unwrap(err)) + }) +} diff --git a/router/core/router.go b/router/core/router.go index f3e3b9c04e..fb97e30d33 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -83,6 +83,7 @@ type ( Config httpServer *server modules []Module + moduleHooks *coreModuleHooks EngineStats statistics.EngineStatistics playgroundHandler func(http.Handler) http.Handler proxy ProxyFunc @@ -640,6 +641,26 @@ func (r *Router) initModules(ctx context.Context) error { return nil } +func (r *Router) initModulesV1(ctx context.Context) error { + moduleList := make([]ModuleV1Info, 0, len(r.customModulesV1)) + + for _, module := range r.customModulesV1 { + moduleList = append(moduleList, module.Module()) + } + + // Add globally registered modules from defaultModuleRegistry + globalModules := defaultModuleRegistry.getModulesV1() + moduleList = append(moduleList, globalModules...) + + if len(moduleList) == 0 { + return nil + } + + r.moduleHooks = newCoreModuleHooks(r.logger) + + return r.moduleHooks.initCoreModuleHooks(ctx, moduleList) +} + func (r *Router) BaseURL() string { return r.baseURL } @@ -915,6 +936,11 @@ func (r *Router) bootstrap(ctx context.Context) error { return fmt.Errorf("failed to init user modules: %w", err) } + // Initialize ModulesV1 system + if err := r.initModulesV1(ctx); err != nil { + return fmt.Errorf("failed to init modulesV1: %w", err) + } + if r.traceConfig.Enabled && len(r.tracePropagators) > 0 { r.compositePropagator = propagation.NewCompositeTextMapPropagator(r.tracePropagators...) @@ -1469,6 +1495,18 @@ func (r *Router) Shutdown(ctx context.Context) error { } }() + // Cleanup ModulesV1 + if r.moduleHooks != nil { + wg.Add(1) + go func() { + defer wg.Done() + + if subErr := r.moduleHooks.cleanupCoreModuleHooks(ctx); subErr != nil { + err.Append(fmt.Errorf("failed to cleanup ModulesV1: %w", subErr)) + } + }() + } + // Shutdown the CDN operation client and free up resources if r.persistedOperationClient != nil { r.persistedOperationClient.Close() @@ -1722,6 +1760,12 @@ func WithCustomModules(modules ...Module) Option { } } +func WithCustomModulesV1(modules ...ModuleV1) Option { + return func(r *Router) { + r.customModulesV1 = modules + } +} + func WithSubgraphTransportOptions(opts *SubgraphTransportOptions) Option { return func(r *Router) { r.subgraphTransportOptions = opts diff --git a/router/core/router_config.go b/router/core/router_config.go index 4173f45e50..363b64ca6e 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -101,6 +101,7 @@ type Config struct { registrationInfo *nodev1.RegistrationInfo securityConfiguration config.SecurityConfiguration customModules []Module + customModulesV1 []ModuleV1 engineExecutionConfiguration config.EngineExecutionConfiguration // should be removed once the users have migrated to the new overrides config overrideRoutingURLConfiguration config.OverrideRoutingURLConfiguration diff --git a/router/internal/utils/ordered_set.go b/router/internal/utils/ordered_set.go new file mode 100644 index 0000000000..aab895b0b5 --- /dev/null +++ b/router/internal/utils/ordered_set.go @@ -0,0 +1,61 @@ +package utils + +type OrderedSet[T comparable] struct { + elements []T + index map[T]struct{} +} + +// NewOrderedSet creates and returns a new OrderedSet. +func NewOrderedSet[T comparable]() *OrderedSet[T] { + return &OrderedSet[T]{ + elements: make([]T, 0), + index: make(map[T]struct{}), + } +} + +// Add inserts elem into the set if it's not already present. +func (s *OrderedSet[T]) Add(elem T) { + if _, exists := s.index[elem]; !exists { + s.index[elem] = struct{}{} + s.elements = append(s.elements, elem) + } +} + +// Remove deletes elem from the set if it exists, preserving order of other elements. +func (s *OrderedSet[T]) Remove(elem T) { + if _, exists := s.index[elem]; exists { + delete(s.index, elem) + // rebuild slice without the removed element + for i, v := range s.elements { + if v == elem { + s.elements = append(s.elements[:i], s.elements[i+1:]...) + break + } + } + } +} + +// Contains returns true if elem is in the set. +func (s *OrderedSet[T]) Contains(elem T) bool { + _, exists := s.index[elem] + return exists +} + +// Values returns a slice of elements in insertion order. +// The returned slice is a copy; modifying it won't affect the set. +func (s *OrderedSet[T]) Values() []T { + dup := make([]T, len(s.elements)) + copy(dup, s.elements) + return dup +} + +// Len returns the number of elements in the set. +func (s *OrderedSet[T]) Len() int { + return len(s.elements) +} + +// Clear removes all elements from the set. +func (s *OrderedSet[T]) Clear() { + s.elements = make([]T, 0) + s.index = make(map[T]struct{}) +} \ No newline at end of file diff --git a/router/internal/utils/ptrs.go b/router/internal/utils/ptrs.go new file mode 100644 index 0000000000..c954c0b883 --- /dev/null +++ b/router/internal/utils/ptrs.go @@ -0,0 +1,6 @@ +package utils + +func Ptr[T any](v T) *T { + return &v +} + From 5f2db98ea1ffb8f9e629791fa5e7bc74d95373ff Mon Sep 17 00:00:00 2001 From: Kaia Lang Date: Wed, 9 Jul 2025 14:00:49 -0700 Subject: [PATCH 2/3] remove global registration and use option-based module registration, add integration test --- .../custom_modules/database_module.go | 127 +++++++++++++++++- .../modules_v1/module_lifecycle_test.go | 63 +++++++++ router/core/hooks_test.go | 29 ++++ router/core/modules_v1.go | 84 ++++-------- router/core/modules_v1_test.go | 69 ---------- router/core/router.go | 14 +- 6 files changed, 246 insertions(+), 140 deletions(-) diff --git a/router-tests/modules_v1/custom_modules/database_module.go b/router-tests/modules_v1/custom_modules/database_module.go index b475be5b98..4287d29afa 100644 --- a/router-tests/modules_v1/custom_modules/database_module.go +++ b/router-tests/modules_v1/custom_modules/database_module.go @@ -1,28 +1,145 @@ package custom_modules import ( + "fmt" + "sync" + "time" + "github.com/wundergraph/cosmo/router/core" + "go.uber.org/zap" ) -type DatabaseModule struct {} +type DatabaseModule struct { + mu sync.RWMutex + connections map[string]*DatabaseConnection + metrics *DatabaseMetrics + isReady bool +} + +type DatabaseConnection struct { + ID string + CreatedAt time.Time + LastUsed time.Time + IsActive bool +} + +type DatabaseMetrics struct { + TotalConnections int + ActiveQueries int + TotalQueries int64 +} func (m *DatabaseModule) Module() core.ModuleV1Info { priority := 2 return core.ModuleV1Info{ - ID: "database_module", + ID: "database_module", Priority: &priority, New: func() core.ModuleV1 { - return &DatabaseModule{} + return m }, } } func (m *DatabaseModule) Provision(ctx *core.ModuleV1Context) error { - ctx.Logger.Info("Database module provisioned") + ctx.Logger.Info("Initializing database module...") + + m.mu.Lock() + defer m.mu.Unlock() + + m.connections = make(map[string]*DatabaseConnection) + m.metrics = &DatabaseMetrics{ + TotalConnections: 0, + ActiveQueries: 0, + TotalQueries: 0, + } + + for i := 0; i < 5; i++ { + connID := fmt.Sprintf("conn_%d", i) + conn := &DatabaseConnection{ + ID: connID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + IsActive: true, + } + m.connections[connID] = conn + m.metrics.TotalConnections++ + } + + m.isReady = true + + ctx.Logger.Info("Database module provisioned successfully", + zap.Int("connections", m.metrics.TotalConnections)) return nil } func (m *DatabaseModule) Cleanup(ctx *core.ModuleV1Context) error { - ctx.Logger.Info("Database module cleaned up") + ctx.Logger.Info("Shutting down database module...") + + m.mu.Lock() + defer m.mu.Unlock() + + for connID, conn := range m.connections { + conn.IsActive = false + ctx.Logger.Info("Closing database connection", zap.String("connection_id", connID)) + delete(m.connections, connID) + } + + m.metrics.TotalConnections = 0 + m.metrics.ActiveQueries = 0 + m.isReady = false + + ctx.Logger.Info("Database module cleaned up successfully") return nil } + +func (m *DatabaseModule) GetConnectionCount() int { + m.mu.RLock() + defer m.mu.RUnlock() + if m.connections == nil { + return 0 + } + return len(m.connections) +} + +func (m *DatabaseModule) SimulateQuery(queryID string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.isReady { + return fmt.Errorf("database module not ready") + } + + var selectedConn *DatabaseConnection + for _, conn := range m.connections { + if conn.IsActive { + selectedConn = conn + break + } + } + + if selectedConn == nil { + return fmt.Errorf("no available connections") + } + + selectedConn.LastUsed = time.Now() + m.metrics.ActiveQueries++ + m.metrics.TotalQueries++ + + go func() { + time.Sleep(10 * time.Millisecond) + m.mu.Lock() + m.metrics.ActiveQueries-- + m.mu.Unlock() + }() + + return nil +} + +func (m *DatabaseModule) GetMetrics() DatabaseMetrics { + m.mu.RLock() + defer m.mu.RUnlock() + if m.metrics == nil { + return DatabaseMetrics{} + } + return *m.metrics +} diff --git a/router-tests/modules_v1/module_lifecycle_test.go b/router-tests/modules_v1/module_lifecycle_test.go index dbe8b76abb..998dc01478 100644 --- a/router-tests/modules_v1/module_lifecycle_test.go +++ b/router-tests/modules_v1/module_lifecycle_test.go @@ -1,29 +1,91 @@ package modules_v1 import ( + "context" "encoding/json" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/modules_v1/custom_modules" "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/core" + "go.uber.org/zap/zaptest" ) func TestModuleV1ProvisionAndCleanupLifecycle(t *testing.T) { t.Parallel() + t.Run("database module lifecycle", func(t *testing.T) { + t.Parallel() + + dbModule := &custom_modules.DatabaseModule{} + logger := zaptest.NewLogger(t) + + moduleCtx := &core.ModuleV1Context{ + Context: context.Background(), + Module: dbModule, + Logger: logger, + } + + err := dbModule.Provision(moduleCtx) + require.NoError(t, err) + + assert.Equal(t, 5, dbModule.GetConnectionCount()) + metrics := dbModule.GetMetrics() + assert.Equal(t, 5, metrics.TotalConnections) + + err = dbModule.SimulateQuery("manual_test_query") + require.NoError(t, err) + + updatedMetrics := dbModule.GetMetrics() + assert.Equal(t, int64(1), updatedMetrics.TotalQueries) + + err = dbModule.Cleanup(moduleCtx) + require.NoError(t, err) + + assert.Equal(t, 0, dbModule.GetConnectionCount()) + finalMetrics := dbModule.GetMetrics() + assert.Equal(t, 0, finalMetrics.TotalConnections) + assert.Equal(t, 0, finalMetrics.ActiveQueries) + + err = dbModule.SimulateQuery("post_cleanup_query") + assert.Error(t, err) + assert.Contains(t, err.Error(), "database module not ready") + }) + t.Run("no regression with the module system introduced", func(t *testing.T) { t.Parallel() dbModule := &custom_modules.DatabaseModule{} + assert.Equal(t, 0, dbModule.GetConnectionCount()) testenv.Run(t, &testenv.Config{ RouterOptions: []core.Option{ core.WithCustomModulesV1(dbModule), }, }, func(t *testing.T, xEnv *testenv.Environment) { + assert.Equal(t, 5, dbModule.GetConnectionCount(), "should have 5 connections after provision") + + metrics := dbModule.GetMetrics() + assert.Equal(t, 5, metrics.TotalConnections) + assert.Equal(t, 0, metrics.ActiveQueries) + assert.Equal(t, int64(0), metrics.TotalQueries) + + err := dbModule.SimulateQuery("test_query_1") + require.NoError(t, err) + + err = dbModule.SimulateQuery("test_query_2") + require.NoError(t, err) + + updatedMetrics := dbModule.GetMetrics() + assert.Equal(t, int64(2), updatedMetrics.TotalQueries, "should have recorded 2 queries") + + time.Sleep(20 * time.Millisecond) + finalMetrics := dbModule.GetMetrics() + assert.Equal(t, 0, finalMetrics.ActiveQueries, "all queries should be completed") + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ Query: `query MyQuery { employees { id } }`, OperationName: json.RawMessage(`"MyQuery"`), @@ -32,5 +94,6 @@ func TestModuleV1ProvisionAndCleanupLifecycle(t *testing.T) { assert.Equal(t, 200, res.Response.StatusCode) assert.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body) }) + }) } diff --git a/router/core/hooks_test.go b/router/core/hooks_test.go index bc0e9e6b9e..8a7e583ae9 100644 --- a/router/core/hooks_test.go +++ b/router/core/hooks_test.go @@ -68,3 +68,32 @@ func TestExecuteHooks(t *testing.T) { require.NoError(t, err) }) } + +func TestHookRegistration(t *testing.T) { + t.Parallel() + + t.Run("module implements both ApplicationStartHook and ApplicationLifecycleHook should only register ApplicationStartHook once", func(t *testing.T) { + hookRegistry := newHookRegistry() + module := &testModule2{} + + hookRegistry.AddApplicationLifecycle(module, "testModule2") + + registerHook(module, hookRegistry.applicationStartHooks, "testModule2") + + require.Equal(t, 1, len(hookRegistry.applicationStartHooks.Values())) + + hooks := hookRegistry.applicationStartHooks.Values() + require.Equal(t, "testModule2", hooks[0].ID) + }) + + t.Run("different modules with same hook type should both be registered", func(t *testing.T) { + hookRegistry := newHookRegistry() + module1 := &testModule1{} + module2 := &testModule2{} + + registerHook(module1, hookRegistry.applicationStartHooks, "testModule1") + registerHook(module2, hookRegistry.applicationStartHooks, "testModule2") + + require.Equal(t, 2, len(hookRegistry.applicationStartHooks.Values())) + }) +} diff --git a/router/core/modules_v1.go b/router/core/modules_v1.go index ecbbc82e8a..08c7a904d4 100644 --- a/router/core/modules_v1.go +++ b/router/core/modules_v1.go @@ -6,29 +6,11 @@ import ( "fmt" "math" "sort" - "sync" "time" "go.uber.org/zap" ) -type moduleRegistry struct { - mu sync.RWMutex - modules map[string]ModuleV1Info -} - -// NewModuleRegistry returns an empty, thread-safe module registry. -// Call this in tests (and anywhere you need isolation) instead of using the global. -func newModuleRegistry() *moduleRegistry { - return &moduleRegistry{ - modules: make(map[string]ModuleV1Info), - } -} - -// defaultModuleRegistry is the package-level registry used by RegisterModuleV1. -// For unit tests you should use newModuleRegistry() to get a fresh instance and avoid shared state. -var defaultModuleRegistry = newModuleRegistry() - type ModuleV1Info struct { // ID is the unique identifier for a module, it must be unique across all modules. ID string @@ -50,6 +32,33 @@ type ModuleV1Context struct { Logger *zap.Logger } +// ModuleV1 interface defines the contract for V1 modules. +// +// IMPORTANT: Concurrency Safety +// If your module stores state (fields, maps, slices, etc.), you MUST handle concurrency properly. +// The router is multi-threaded and your module methods may be called concurrently from different goroutines. +// +// Use synchronization primitives like sync.RWMutex for thread-safe access: +// +// type MyModule struct { +// mu sync.RWMutex +// data map[string]int +// } +// +// func (m *MyModule) SafeRead() int { +// m.mu.RLock() +// defer m.mu.RUnlock() +// return m.data["key"] +// } +// +// func (m *MyModule) SafeWrite(key string, value int) { +// m.mu.Lock() +// defer m.mu.Unlock() +// m.data[key] = value +// } +// +// Hook methods (if implemented) will be called concurrently during request processing. +// Provision() and Cleanup() are called once during router startup/shutdown and are inherently safe. type ModuleV1 interface { Module() ModuleV1Info // Provisioner is called before the server starts @@ -60,33 +69,6 @@ type ModuleV1 interface { Cleanup(ctx *ModuleV1Context) error } -// RegisterModuleV1 registers a new ModuleV1 instance. -// The registration order matters. Modules with the same priority -// are executed in the order they are registered. -// It panics if the module is already registered. -func RegisterModuleV1(instance ModuleV1) { - defaultModuleRegistry.registerModuleV1(instance) -} - -func (r *moduleRegistry) registerModuleV1(instance ModuleV1) { - m := instance.Module() - - if m.ID == "" { - panic("ModuleV1.ID is required") - } - if val := m.New(); val == nil { - panic(fmt.Sprintf("ModuleV1Info.New must return a non-nil module instance: %s", m.ID)) - } - - r.mu.Lock() - defer r.mu.Unlock() - - if _, ok := r.modules[m.ID]; ok { - panic(fmt.Sprintf("ModuleV1 already registered: %s", m.ID)) - } - r.modules[m.ID] = m -} - // sortModulesV1 sorts the modules by priority, 0 is the highest priority, is the first to be executed. // If two modules have the same priority, they are sorted by registration order. // If a module has no priority, it is considered to have the lowest priority. @@ -105,18 +87,6 @@ func sortModulesV1(modules []ModuleV1Info) []ModuleV1Info { return modules } -// getModulesV1 returns all registered modules sorted by priority -func (r *moduleRegistry) getModulesV1() []ModuleV1Info { - r.mu.RLock() - defer r.mu.RUnlock() - - modules := make([]ModuleV1Info, 0, len(r.modules)) - for _, m := range r.modules { - modules = append(modules, m) - } - return sortModulesV1(modules) -} - // coreModuleHooks manages module initialization and hook registration. type coreModuleHooks struct { moduleInstances []ModuleV1 diff --git a/router/core/modules_v1_test.go b/router/core/modules_v1_test.go index 3b6cc61f2d..03e451b047 100644 --- a/router/core/modules_v1_test.go +++ b/router/core/modules_v1_test.go @@ -81,75 +81,6 @@ var _ ApplicationLifecycleHook = (*testModule2)(nil) var _ ApplicationStartHook = (*testModule3)(nil) var _ ApplicationStopHook = (*testModule3)(nil) -func TestRegisterModuleV1(t *testing.T) { - t.Parallel() - - m1 := &testModule1{} - m2 := &testModule2{} - m3 := &testModule3{} - m4 := &testModule4{} - m5 := &testModule1{} - t.Run("success", func(t *testing.T) { - testModuleRegistry := newModuleRegistry() - - testModuleRegistry.registerModuleV1(m1) - testModuleRegistry.registerModuleV1(m2) - - require.Equal(t, "testModule1", testModuleRegistry.getModulesV1()[0].ID) - require.Equal(t, "testModule2", testModuleRegistry.getModulesV1()[1].ID) - }) - - t.Run("panic if module id is empty", func(t *testing.T) { - testModuleRegistry := newModuleRegistry() - - require.Panics(t, func() { - testModuleRegistry.registerModuleV1(m3) - }) - }) - - t.Run("panic if module new returns nil", func(t *testing.T) { - testModuleRegistry := newModuleRegistry() - - require.Panics(t, func() { - testModuleRegistry.registerModuleV1(m4) - }) - }) - - t.Run("panic if module id is not unique", func(t *testing.T) { - testModuleRegistry := newModuleRegistry() - - require.Panics(t, func() { - testModuleRegistry.registerModuleV1(m1) - testModuleRegistry.registerModuleV1(m5) - }) - }) - - t.Run("module implements both ApplicationStartHook and ApplicationLifecycleHook should only register ApplicationStartHook once", func(t *testing.T) { - hookRegistry := newHookRegistry() - module := &testModule2{} - - hookRegistry.AddApplicationLifecycle(module, "testModule2") - - registerHook(module, hookRegistry.applicationStartHooks, "testModule2") - - require.Equal(t, 1, len(hookRegistry.applicationStartHooks.Values())) - - hooks := hookRegistry.applicationStartHooks.Values() - require.Equal(t, "testModule2", hooks[0].ID) - }) - - t.Run("different modules with same hook type should both be registered", func(t *testing.T) { - hookRegistry := newHookRegistry() - module1 := &testModule1{} - module2 := &testModule2{} - - registerHook(module1, hookRegistry.applicationStartHooks, "testModule1") - registerHook(module2, hookRegistry.applicationStartHooks, "testModule2") - - require.Equal(t, 2, len(hookRegistry.applicationStartHooks.Values())) - }) -} - func TestSortModulesV1(t *testing.T) { t.Parallel() diff --git a/router/core/router.go b/router/core/router.go index fb97e30d33..7daf7b7f0e 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -642,20 +642,16 @@ func (r *Router) initModules(ctx context.Context) error { } func (r *Router) initModulesV1(ctx context.Context) error { - moduleList := make([]ModuleV1Info, 0, len(r.customModulesV1)) + if len(r.customModulesV1) == 0 { + return nil + } + moduleList := make([]ModuleV1Info, 0, len(r.customModulesV1)) for _, module := range r.customModulesV1 { moduleList = append(moduleList, module.Module()) } - // Add globally registered modules from defaultModuleRegistry - globalModules := defaultModuleRegistry.getModulesV1() - moduleList = append(moduleList, globalModules...) - - if len(moduleList) == 0 { - return nil - } - + moduleList = sortModulesV1(moduleList) r.moduleHooks = newCoreModuleHooks(r.logger) return r.moduleHooks.initCoreModuleHooks(ctx, moduleList) From d4d4db513a72351666317cddabd2cf942e482bc2 Mon Sep 17 00:00:00 2001 From: Kaia Lang Date: Wed, 9 Jul 2025 17:01:29 -0700 Subject: [PATCH 3/3] address coderabbitai comments --- .../modules_v1/custom_modules/database_module.go | 7 +++++++ router-tests/modules_v1/module_lifecycle_test.go | 6 +++--- router/core/modules_v1.go | 12 +++++------- router/core/router_config.go | 1 + 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/router-tests/modules_v1/custom_modules/database_module.go b/router-tests/modules_v1/custom_modules/database_module.go index 4287d29afa..12067bb60a 100644 --- a/router-tests/modules_v1/custom_modules/database_module.go +++ b/router-tests/modules_v1/custom_modules/database_module.go @@ -14,6 +14,7 @@ type DatabaseModule struct { connections map[string]*DatabaseConnection metrics *DatabaseMetrics isReady bool + wg sync.WaitGroup } type DatabaseConnection struct { @@ -35,6 +36,8 @@ func (m *DatabaseModule) Module() core.ModuleV1Info { ID: "database_module", Priority: &priority, New: func() core.ModuleV1 { + // For testing purposes, return the same instance so tests can inspect state. + // In production modules, you'd typically return &DatabaseModule{} for isolation. return m }, } @@ -75,6 +78,8 @@ func (m *DatabaseModule) Provision(ctx *core.ModuleV1Context) error { func (m *DatabaseModule) Cleanup(ctx *core.ModuleV1Context) error { ctx.Logger.Info("Shutting down database module...") + m.wg.Wait() + m.mu.Lock() defer m.mu.Unlock() @@ -125,7 +130,9 @@ func (m *DatabaseModule) SimulateQuery(queryID string) error { m.metrics.ActiveQueries++ m.metrics.TotalQueries++ + m.wg.Add(1) go func() { + defer m.wg.Done() time.Sleep(10 * time.Millisecond) m.mu.Lock() m.metrics.ActiveQueries-- diff --git a/router-tests/modules_v1/module_lifecycle_test.go b/router-tests/modules_v1/module_lifecycle_test.go index 998dc01478..67fe8cd528 100644 --- a/router-tests/modules_v1/module_lifecycle_test.go +++ b/router-tests/modules_v1/module_lifecycle_test.go @@ -82,9 +82,9 @@ func TestModuleV1ProvisionAndCleanupLifecycle(t *testing.T) { updatedMetrics := dbModule.GetMetrics() assert.Equal(t, int64(2), updatedMetrics.TotalQueries, "should have recorded 2 queries") - time.Sleep(20 * time.Millisecond) - finalMetrics := dbModule.GetMetrics() - assert.Equal(t, 0, finalMetrics.ActiveQueries, "all queries should be completed") + require.Eventually(t, func() bool { + return dbModule.GetMetrics().ActiveQueries == 0 + }, 1*time.Second, 1*time.Millisecond, "all queries should be completed") res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ Query: `query MyQuery { employees { id } }`, diff --git a/router/core/modules_v1.go b/router/core/modules_v1.go index 08c7a904d4..919244d509 100644 --- a/router/core/modules_v1.go +++ b/router/core/modules_v1.go @@ -104,7 +104,6 @@ func newCoreModuleHooks(logger *zap.Logger) *coreModuleHooks { // initCoreModuleHooks instantiates each module, provisions it, // registers any implemented hooks, and saves the hook registry. func (c *coreModuleHooks) initCoreModuleHooks(ctx context.Context, modules []ModuleV1Info) error { - hookRegistry := newHookRegistry() var instances []ModuleV1 for _, info := range modules { @@ -121,11 +120,11 @@ func (c *coreModuleHooks) initCoreModuleHooks(ctx context.Context, modules []Mod return newModuleV1Error(info.ID, PhaseProvision, err) } - hookRegistry.AddApplicationLifecycle(moduleInstance, info.ID) - hookRegistry.AddGraphQLServerLifecycle(moduleInstance, info.ID) - hookRegistry.AddRouterLifecycle(moduleInstance, info.ID) - hookRegistry.AddSubgraphLifecycle(moduleInstance, info.ID) - hookRegistry.AddOperationLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddApplicationLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddGraphQLServerLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddRouterLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddSubgraphLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddOperationLifecycle(moduleInstance, info.ID) c.logger.Info("Core Module System: Module registered", zap.String("id", string(info.ID)), @@ -135,7 +134,6 @@ func (c *coreModuleHooks) initCoreModuleHooks(ctx context.Context, modules []Mod instances = append(instances, moduleInstance) } - c.hookRegistry = hookRegistry c.moduleInstances = instances return nil diff --git a/router/core/router_config.go b/router/core/router_config.go index 363b64ca6e..17f4185ec9 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -203,6 +203,7 @@ func (c *Config) Usage() map[string]any { usage["prometheus"] = c.prometheusServer != nil usage["custom_modules"] = len(c.customModules) > 0 + usage["custom_modules_v1"] = len(c.customModulesV1) > 0 usage["header_rules"] = c.headerRules != nil && (c.headerRules.All != nil || len(c.headerRules.Subgraphs) > 0) usage["subgraph_transport_options"] = c.subgraphTransportOptions != nil usage["graphql_metrics"] = c.graphqlMetricsConfig != nil && c.graphqlMetricsConfig.Enabled