diff --git a/router-tests/modules_v1/custom_modules/database_module.go b/router-tests/modules_v1/custom_modules/database_module.go new file mode 100644 index 0000000000..12067bb60a --- /dev/null +++ b/router-tests/modules_v1/custom_modules/database_module.go @@ -0,0 +1,152 @@ +package custom_modules + +import ( + "fmt" + "sync" + "time" + + "github.com/wundergraph/cosmo/router/core" + "go.uber.org/zap" +) + +type DatabaseModule struct { + mu sync.RWMutex + connections map[string]*DatabaseConnection + metrics *DatabaseMetrics + isReady bool + wg sync.WaitGroup +} + +type DatabaseConnection struct { + ID string + CreatedAt time.Time + LastUsed time.Time + IsActive bool +} + +type DatabaseMetrics struct { + TotalConnections int + ActiveQueries int + TotalQueries int64 +} + +func (m *DatabaseModule) Module() core.ModuleV1Info { + priority := 2 + return core.ModuleV1Info{ + ID: "database_module", + Priority: &priority, + New: func() core.ModuleV1 { + // For testing purposes, return the same instance so tests can inspect state. + // In production modules, you'd typically return &DatabaseModule{} for isolation. + return m + }, + } +} + +func (m *DatabaseModule) Provision(ctx *core.ModuleV1Context) error { + ctx.Logger.Info("Initializing database module...") + + m.mu.Lock() + defer m.mu.Unlock() + + m.connections = make(map[string]*DatabaseConnection) + m.metrics = &DatabaseMetrics{ + TotalConnections: 0, + ActiveQueries: 0, + TotalQueries: 0, + } + + for i := 0; i < 5; i++ { + connID := fmt.Sprintf("conn_%d", i) + conn := &DatabaseConnection{ + ID: connID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + IsActive: true, + } + m.connections[connID] = conn + m.metrics.TotalConnections++ + } + + m.isReady = true + + ctx.Logger.Info("Database module provisioned successfully", + zap.Int("connections", m.metrics.TotalConnections)) + return nil +} + +func (m *DatabaseModule) Cleanup(ctx *core.ModuleV1Context) error { + ctx.Logger.Info("Shutting down database module...") + + m.wg.Wait() + + m.mu.Lock() + defer m.mu.Unlock() + + for connID, conn := range m.connections { + conn.IsActive = false + ctx.Logger.Info("Closing database connection", zap.String("connection_id", connID)) + delete(m.connections, connID) + } + + m.metrics.TotalConnections = 0 + m.metrics.ActiveQueries = 0 + m.isReady = false + + ctx.Logger.Info("Database module cleaned up successfully") + return nil +} + +func (m *DatabaseModule) GetConnectionCount() int { + m.mu.RLock() + defer m.mu.RUnlock() + if m.connections == nil { + return 0 + } + return len(m.connections) +} + +func (m *DatabaseModule) SimulateQuery(queryID string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.isReady { + return fmt.Errorf("database module not ready") + } + + var selectedConn *DatabaseConnection + for _, conn := range m.connections { + if conn.IsActive { + selectedConn = conn + break + } + } + + if selectedConn == nil { + return fmt.Errorf("no available connections") + } + + selectedConn.LastUsed = time.Now() + m.metrics.ActiveQueries++ + m.metrics.TotalQueries++ + + m.wg.Add(1) + go func() { + defer m.wg.Done() + time.Sleep(10 * time.Millisecond) + m.mu.Lock() + m.metrics.ActiveQueries-- + m.mu.Unlock() + }() + + return nil +} + +func (m *DatabaseModule) GetMetrics() DatabaseMetrics { + m.mu.RLock() + defer m.mu.RUnlock() + if m.metrics == nil { + return DatabaseMetrics{} + } + return *m.metrics +} diff --git a/router-tests/modules_v1/module_lifecycle_test.go b/router-tests/modules_v1/module_lifecycle_test.go new file mode 100644 index 0000000000..67fe8cd528 --- /dev/null +++ b/router-tests/modules_v1/module_lifecycle_test.go @@ -0,0 +1,99 @@ +package modules_v1 + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/modules_v1/custom_modules" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "go.uber.org/zap/zaptest" +) + +func TestModuleV1ProvisionAndCleanupLifecycle(t *testing.T) { + t.Parallel() + + t.Run("database module lifecycle", func(t *testing.T) { + t.Parallel() + + dbModule := &custom_modules.DatabaseModule{} + logger := zaptest.NewLogger(t) + + moduleCtx := &core.ModuleV1Context{ + Context: context.Background(), + Module: dbModule, + Logger: logger, + } + + err := dbModule.Provision(moduleCtx) + require.NoError(t, err) + + assert.Equal(t, 5, dbModule.GetConnectionCount()) + metrics := dbModule.GetMetrics() + assert.Equal(t, 5, metrics.TotalConnections) + + err = dbModule.SimulateQuery("manual_test_query") + require.NoError(t, err) + + updatedMetrics := dbModule.GetMetrics() + assert.Equal(t, int64(1), updatedMetrics.TotalQueries) + + err = dbModule.Cleanup(moduleCtx) + require.NoError(t, err) + + assert.Equal(t, 0, dbModule.GetConnectionCount()) + finalMetrics := dbModule.GetMetrics() + assert.Equal(t, 0, finalMetrics.TotalConnections) + assert.Equal(t, 0, finalMetrics.ActiveQueries) + + err = dbModule.SimulateQuery("post_cleanup_query") + assert.Error(t, err) + assert.Contains(t, err.Error(), "database module not ready") + }) + + t.Run("no regression with the module system introduced", func(t *testing.T) { + t.Parallel() + + dbModule := &custom_modules.DatabaseModule{} + assert.Equal(t, 0, dbModule.GetConnectionCount()) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithCustomModulesV1(dbModule), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + assert.Equal(t, 5, dbModule.GetConnectionCount(), "should have 5 connections after provision") + + metrics := dbModule.GetMetrics() + assert.Equal(t, 5, metrics.TotalConnections) + assert.Equal(t, 0, metrics.ActiveQueries) + assert.Equal(t, int64(0), metrics.TotalQueries) + + err := dbModule.SimulateQuery("test_query_1") + require.NoError(t, err) + + err = dbModule.SimulateQuery("test_query_2") + require.NoError(t, err) + + updatedMetrics := dbModule.GetMetrics() + assert.Equal(t, int64(2), updatedMetrics.TotalQueries, "should have recorded 2 queries") + + require.Eventually(t, func() bool { + return dbModule.GetMetrics().ActiveQueries == 0 + }, 1*time.Second, 1*time.Millisecond, "all queries should be completed") + + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query MyQuery { employees { id } }`, + OperationName: json.RawMessage(`"MyQuery"`), + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + assert.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body) + }) + + }) +} diff --git a/router/core/hooks.go b/router/core/hooks.go new file mode 100644 index 0000000000..2f781ed59c --- /dev/null +++ b/router/core/hooks.go @@ -0,0 +1,264 @@ +package core + +import ( + "github.com/wundergraph/cosmo/router/internal/utils" + "go.uber.org/zap" +) + +// Application Lifecycle Hooks +type ApplicationLifecycleHook interface { + ApplicationStartHook + ApplicationStopHook +} + +type ApplicationStartHook interface { + OnApplicationStart(ctx ApplicationStartHookContext) error +} + +type ApplicationStopHook interface { + OnApplicationStop(ctx ApplicationStopHookContext) error +} + +// GraphQL Server Lifecycle Hooks +type GraphQLServerLifecycleHook interface { + GraphQLServerStartHook + GraphQLServerStopHook +} + +type GraphQLServerStartHook interface { + OnGraphQLServerStart(ctx GraphQLServerStartHookContext) error +} + +type GraphQLServerStopHook interface { + OnGraphQLServerStop(ctx GraphQLServerStopHookContext) error +} + +// Router Lifecycle Hooks +type RouterRequestHook interface { + OnRouterRequest(ctx RouterRequestHookContext) error +} + +type RouterResponseHook interface { + OnRouterResponse(ctx RouterResponseHookContext) error +} + +type RouterLifecycleHook interface { + RouterRequestHook + RouterResponseHook +} + +// Subgraph Lifecycle Hooks +type SubgraphRequestHook interface { + OnSubgraphRequest(ctx SubgraphRequestHookContext) error +} + +type SubgraphResponseHook interface { + OnSubgraphResponse(ctx SubgraphResponseHookContext) error +} + +type SubgraphLifecycleHook interface { + SubgraphRequestHook + SubgraphResponseHook +} + +// Operation Lifecycle Hooks +type OperationLifecycleHook interface { + OperationParseLifecycleHook + OperationNormalizeLifecycleHook + OperationValidateLifecycleHook + OperationPlanLifecycleHook + OperationExecuteLifecycleHook +} + +type OperationParseLifecycleHook interface { + OperationPreParseHook + OperationPostParseHook +} + +type OperationPreParseHook interface { + OnOperationPreParse(ctx OperationPreParseHookContext) error +} + +type OperationPostParseHook interface { + OnOperationPostParse(ctx OperationPostParseHookContext) error +} + +type OperationNormalizeLifecycleHook interface { + OperationPreNormalizeHook + OperationPostNormalizeHook +} + +type OperationPreNormalizeHook interface { + OnOperationPreNormalize(ctx OperationPreNormalizeHookContext) error +} + +type OperationPostNormalizeHook interface { + OnOperationPostNormalize(ctx OperationPostNormalizeHookContext) error +} + +type OperationValidateLifecycleHook interface { + OperationPreValidateHook + OperationPostValidateHook +} + +type OperationPreValidateHook interface { + OnOperationPreValidate(ctx OperationPreValidateHookContext) error +} + +type OperationPostValidateHook interface { + OnOperationPostValidate(ctx OperationPostValidateHookContext) error +} + +type OperationPlanLifecycleHook interface { + OperationPrePlanHook + OperationPostPlanHook +} + +type OperationPrePlanHook interface { + OnOperationPrePlan(ctx OperationPrePlanHookContext) error +} + +type OperationPostPlanHook interface { + OnOperationPostPlan(ctx OperationPostPlanHookContext) error +} + +type OperationExecuteLifecycleHook interface { + OperationPreExecuteHook + OperationPostExecuteHook +} + +type OperationPreExecuteHook interface { + OnOperationPreExecute(ctx OperationPreExecuteHookContext) error +} + +type OperationPostExecuteHook interface { + OnOperationPostExecute(ctx OperationPostExecuteHookContext) error +} + +// moduleHook is a wrapper around a hook that includes the module ID. +// this is used for tracability in case of hook execution errors. +type moduleHook[H any] struct { + ID string + Hook H +} + +// hookRegistry holds the list of hooks for each type. +type hookRegistry struct { + applicationStartHooks *utils.OrderedSet[moduleHook[ApplicationStartHook]] + applicationStopHooks *utils.OrderedSet[moduleHook[ApplicationStopHook]] + + graphQLServerStartHooks *utils.OrderedSet[moduleHook[GraphQLServerStartHook]] + graphQLServerStopHooks *utils.OrderedSet[moduleHook[GraphQLServerStopHook]] + + routerRequestHooks *utils.OrderedSet[moduleHook[RouterRequestHook]] + routerResponseHooks *utils.OrderedSet[moduleHook[RouterResponseHook]] + + subgraphRequestHooks *utils.OrderedSet[moduleHook[SubgraphRequestHook]] + subgraphResponseHooks *utils.OrderedSet[moduleHook[SubgraphResponseHook]] + + operationPreParseHooks *utils.OrderedSet[moduleHook[OperationPreParseHook]] + operationPostParseHooks *utils.OrderedSet[moduleHook[OperationPostParseHook]] + + operationPreNormalizeHooks *utils.OrderedSet[moduleHook[OperationPreNormalizeHook]] + operationPostNormalizeHooks *utils.OrderedSet[moduleHook[OperationPostNormalizeHook]] + + operationPreValidateHooks *utils.OrderedSet[moduleHook[OperationPreValidateHook]] + operationPostValidateHooks *utils.OrderedSet[moduleHook[OperationPostValidateHook]] + + operationPrePlanHooks *utils.OrderedSet[moduleHook[OperationPrePlanHook]] + operationPostPlanHooks *utils.OrderedSet[moduleHook[OperationPostPlanHook]] + + operationPreExecuteHooks *utils.OrderedSet[moduleHook[OperationPreExecuteHook]] + operationPostExecuteHooks *utils.OrderedSet[moduleHook[OperationPostExecuteHook]] +} + +// newHookRegistry initializes with empty sets. +func newHookRegistry() *hookRegistry { + return &hookRegistry{ + applicationStartHooks: utils.NewOrderedSet[moduleHook[ApplicationStartHook]](), + applicationStopHooks: utils.NewOrderedSet[moduleHook[ApplicationStopHook]](), + + graphQLServerStartHooks: utils.NewOrderedSet[moduleHook[GraphQLServerStartHook]](), + graphQLServerStopHooks: utils.NewOrderedSet[moduleHook[GraphQLServerStopHook]](), + + routerRequestHooks: utils.NewOrderedSet[moduleHook[RouterRequestHook]](), + routerResponseHooks: utils.NewOrderedSet[moduleHook[RouterResponseHook]](), + + subgraphRequestHooks: utils.NewOrderedSet[moduleHook[SubgraphRequestHook]](), + subgraphResponseHooks: utils.NewOrderedSet[moduleHook[SubgraphResponseHook]](), + + operationPreParseHooks: utils.NewOrderedSet[moduleHook[OperationPreParseHook]](), + operationPostParseHooks: utils.NewOrderedSet[moduleHook[OperationPostParseHook]](), + + operationPreNormalizeHooks: utils.NewOrderedSet[moduleHook[OperationPreNormalizeHook]](), + operationPostNormalizeHooks: utils.NewOrderedSet[moduleHook[OperationPostNormalizeHook]](), + + operationPreValidateHooks: utils.NewOrderedSet[moduleHook[OperationPreValidateHook]](), + operationPostValidateHooks: utils.NewOrderedSet[moduleHook[OperationPostValidateHook]](), + + operationPrePlanHooks: utils.NewOrderedSet[moduleHook[OperationPrePlanHook]](), + operationPostPlanHooks: utils.NewOrderedSet[moduleHook[OperationPostPlanHook]](), + + operationPreExecuteHooks: utils.NewOrderedSet[moduleHook[OperationPreExecuteHook]](), + operationPostExecuteHooks: utils.NewOrderedSet[moduleHook[OperationPostExecuteHook]](), + } +} + +// registerHook is a helper to add any hook type if implemented. +func registerHook[H comparable](inst any, set *utils.OrderedSet[moduleHook[H]], moduleID string) { + if h, ok := inst.(H); ok { + set.Add(moduleHook[H]{ + ID: moduleID, + Hook: h, + }) + } +} + +// AddApplicationLifecycle registers start/stop hooks. +func (hr *hookRegistry) AddApplicationLifecycle(inst any, moduleID string) { + registerHook(inst, hr.applicationStartHooks, moduleID) + registerHook(inst, hr.applicationStopHooks, moduleID) +} + +// AddGraphQLServerLifecycle registers GraphQL server start/stop hooks. +func (hr *hookRegistry) AddGraphQLServerLifecycle(inst any, moduleID string) { + registerHook(inst, hr.graphQLServerStartHooks, moduleID) + registerHook(inst, hr.graphQLServerStopHooks, moduleID) +} + +// AddRouterLifecycle registers router request/response hooks. +func (hr *hookRegistry) AddRouterLifecycle(inst any, moduleID string) { + registerHook(inst, hr.routerRequestHooks, moduleID) + registerHook(inst, hr.routerResponseHooks, moduleID) +} + +// AddSubgraphLifecycle registers subgraph request/response hooks. +func (hr *hookRegistry) AddSubgraphLifecycle(inst any, moduleID string) { + registerHook(inst, hr.subgraphRequestHooks, moduleID) + registerHook(inst, hr.subgraphResponseHooks, moduleID) +} + +// AddOperationLifecycle registers all operation lifecycle hooks. +func (hr *hookRegistry) AddOperationLifecycle(inst any, moduleID string) { + registerHook(inst, hr.operationPreParseHooks, moduleID) + registerHook(inst, hr.operationPostParseHooks, moduleID) + registerHook(inst, hr.operationPreNormalizeHooks, moduleID) + registerHook(inst, hr.operationPostNormalizeHooks, moduleID) + registerHook(inst, hr.operationPreValidateHooks, moduleID) + registerHook(inst, hr.operationPostValidateHooks, moduleID) + registerHook(inst, hr.operationPrePlanHooks, moduleID) + registerHook(inst, hr.operationPostPlanHooks, moduleID) + registerHook(inst, hr.operationPreExecuteHooks, moduleID) + registerHook(inst, hr.operationPostExecuteHooks, moduleID) +} + +// executeHooks executes the hooks in the order they were registered. +func executeHooks[H any](hooks []moduleHook[H], invoke func(H) error, hookName string, logger *zap.Logger) error { + logger.Debug("executing hooks", zap.String("hookName", hookName), zap.Int("hooks", len(hooks))) + for _, mk := range hooks { + if err := invoke(mk.Hook); err != nil { + return newModuleV1HookError(mk.ID, hookName, err) + } + } + return nil +} diff --git a/router/core/hooks_context.go b/router/core/hooks_context.go new file mode 100644 index 0000000000..012a9a41a9 --- /dev/null +++ b/router/core/hooks_context.go @@ -0,0 +1,75 @@ +package core + +// context interface for every hook +type ApplicationStartHookContext interface { + RequestContext +} + +type ApplicationStopHookContext interface { + RequestContext +} + +type GraphQLServerStartHookContext interface { + RequestContext +} + +type GraphQLServerStopHookContext interface { + RequestContext +} + +type RouterRequestHookContext interface { + RequestContext +} + +type RouterResponseHookContext interface { + RequestContext +} + +type SubgraphRequestHookContext interface { + RequestContext +} + +type SubgraphResponseHookContext interface { + RequestContext +} + +type OperationPreParseHookContext interface { + RequestContext +} + +type OperationPostParseHookContext interface { + RequestContext +} + +type OperationPreNormalizeHookContext interface { + RequestContext +} + +type OperationPostNormalizeHookContext interface { + RequestContext +} + +type OperationPreValidateHookContext interface { + RequestContext +} + +type OperationPostValidateHookContext interface { + RequestContext +} + +type OperationPrePlanHookContext interface { + RequestContext +} + +type OperationPostPlanHookContext interface { + RequestContext +} + +type OperationPreExecuteHookContext interface { + RequestContext +} + +type OperationPostExecuteHookContext interface { + RequestContext +} + diff --git a/router/core/hooks_test.go b/router/core/hooks_test.go new file mode 100644 index 0000000000..8a7e583ae9 --- /dev/null +++ b/router/core/hooks_test.go @@ -0,0 +1,99 @@ +package core + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +type mockHook interface { + Do(ctx context.Context) error +} + +type mockHookImpl struct { + id string + shouldErr bool +} + +func (m *mockHookImpl) Do(ctx context.Context) error { + if m.shouldErr { + return fmt.Errorf("hook %s failed", m.id) + } + return nil +} + +func TestExecuteHooks(t *testing.T) { + ctx := context.Background() + + t.Run("all hooks succeed", func(t *testing.T) { + hooks := []moduleHook[mockHook]{ + {ID: "module1", Hook: &mockHookImpl{id: "hook1"}}, + {ID: "module2", Hook: &mockHookImpl{id: "hook2"}}, + } + + err := executeHooks(hooks, func(h mockHook) error { + return h.Do(ctx) + }, "MockHook", zaptest.NewLogger(t)) + + require.NoError(t, err) + }) + + t.Run("one hook fails", func(t *testing.T) { + hooks := []moduleHook[mockHook]{ + {ID: "module1", Hook: &mockHookImpl{id: "hook1"}}, + {ID: "moduleFail", Hook: &mockHookImpl{id: "hook2", shouldErr: true}}, + {ID: "module3", Hook: &mockHookImpl{id: "hook3"}}, + } + + err := executeHooks(hooks, func(h mockHook) error { + return h.Do(ctx) + }, "MockHook", zaptest.NewLogger(t)) + + require.Error(t, err) + + assert.Equal(t, err.Error(), "module moduleFail hook MockHook error: hook hook2 failed") + }) + + t.Run("empty hook list", func(t *testing.T) { + var hooks []moduleHook[mockHook] + + err := executeHooks(hooks, func(h mockHook) error { + return h.Do(ctx) + }, "MockHook", zaptest.NewLogger(t)) + + require.NoError(t, err) + }) +} + +func TestHookRegistration(t *testing.T) { + t.Parallel() + + t.Run("module implements both ApplicationStartHook and ApplicationLifecycleHook should only register ApplicationStartHook once", func(t *testing.T) { + hookRegistry := newHookRegistry() + module := &testModule2{} + + hookRegistry.AddApplicationLifecycle(module, "testModule2") + + registerHook(module, hookRegistry.applicationStartHooks, "testModule2") + + require.Equal(t, 1, len(hookRegistry.applicationStartHooks.Values())) + + hooks := hookRegistry.applicationStartHooks.Values() + require.Equal(t, "testModule2", hooks[0].ID) + }) + + t.Run("different modules with same hook type should both be registered", func(t *testing.T) { + hookRegistry := newHookRegistry() + module1 := &testModule1{} + module2 := &testModule2{} + + registerHook(module1, hookRegistry.applicationStartHooks, "testModule1") + registerHook(module2, hookRegistry.applicationStartHooks, "testModule2") + + require.Equal(t, 2, len(hookRegistry.applicationStartHooks.Values())) + }) +} diff --git a/router/core/modules_v1.go b/router/core/modules_v1.go new file mode 100644 index 0000000000..919244d509 --- /dev/null +++ b/router/core/modules_v1.go @@ -0,0 +1,200 @@ +package core + +import ( + "context" + "errors" + "fmt" + "math" + "sort" + "time" + + "go.uber.org/zap" +) + +type ModuleV1Info struct { + // ID is the unique identifier for a module, it must be unique across all modules. + ID string + // Priority decides the order of execution of the module. + // The smaller the number, the higher the priority, the earlier the module is executed. + // For example, a priority of 0 is the highest priority. + // Modules with the same priority are executed in the order they are registered. + // If Priority is nil, the module is considered to have the lowest priority. + Priority *int + // New creates a new instance of the module. + New func() ModuleV1 +} + +// ModuleV1Context provides context and utilities for module provisioning +// Maintains feature parity with the old module system +type ModuleV1Context struct { + context.Context + Module ModuleV1 + Logger *zap.Logger +} + +// ModuleV1 interface defines the contract for V1 modules. +// +// IMPORTANT: Concurrency Safety +// If your module stores state (fields, maps, slices, etc.), you MUST handle concurrency properly. +// The router is multi-threaded and your module methods may be called concurrently from different goroutines. +// +// Use synchronization primitives like sync.RWMutex for thread-safe access: +// +// type MyModule struct { +// mu sync.RWMutex +// data map[string]int +// } +// +// func (m *MyModule) SafeRead() int { +// m.mu.RLock() +// defer m.mu.RUnlock() +// return m.data["key"] +// } +// +// func (m *MyModule) SafeWrite(key string, value int) { +// m.mu.Lock() +// defer m.mu.Unlock() +// m.data[key] = value +// } +// +// Hook methods (if implemented) will be called concurrently during request processing. +// Provision() and Cleanup() are called once during router startup/shutdown and are inherently safe. +type ModuleV1 interface { + Module() ModuleV1Info + // Provisioner is called before the server starts + // It allows you to initialize your module e.g. create a database connection + Provision(ctx *ModuleV1Context) error + // Cleanup is called after the server stops + // It allows you to clean up your module e.g. close a database connection + Cleanup(ctx *ModuleV1Context) error +} + +// sortModulesV1 sorts the modules by priority, 0 is the highest priority, is the first to be executed. +// If two modules have the same priority, they are sorted by registration order. +// If a module has no priority, it is considered to have the lowest priority. +func sortModulesV1(modules []ModuleV1Info) []ModuleV1Info { + sort.Slice(modules, func(i, j int) bool { + var priorityI, priorityJ int = math.MaxInt, math.MaxInt + if modules[i].Priority != nil { + priorityI = *modules[i].Priority + } + if modules[j].Priority != nil { + priorityJ = *modules[j].Priority + } + + return priorityI < priorityJ + }) + return modules +} + +// coreModuleHooks manages module initialization and hook registration. +type coreModuleHooks struct { + moduleInstances []ModuleV1 + hookRegistry *hookRegistry + logger *zap.Logger +} + +func newCoreModuleHooks(logger *zap.Logger) *coreModuleHooks { + return &coreModuleHooks{ + hookRegistry: newHookRegistry(), + logger: logger, + } +} + +// initCoreModuleHooks instantiates each module, provisions it, +// registers any implemented hooks, and saves the hook registry. +func (c *coreModuleHooks) initCoreModuleHooks(ctx context.Context, modules []ModuleV1Info) error { + var instances []ModuleV1 + + for _, info := range modules { + now := time.Now() + moduleInstance := info.New() + + moduleCtx := &ModuleV1Context{ + Context: ctx, + Module: moduleInstance, + Logger: c.logger.Named(info.ID), + } + + if err := moduleInstance.Provision(moduleCtx); err != nil { + return newModuleV1Error(info.ID, PhaseProvision, err) + } + + c.hookRegistry.AddApplicationLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddGraphQLServerLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddRouterLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddSubgraphLifecycle(moduleInstance, info.ID) + c.hookRegistry.AddOperationLifecycle(moduleInstance, info.ID) + + c.logger.Info("Core Module System: Module registered", + zap.String("id", string(info.ID)), + zap.String("duration", time.Since(now).String()), + ) + + instances = append(instances, moduleInstance) + } + + c.moduleInstances = instances + + return nil +} + +func (c *coreModuleHooks) cleanupCoreModuleHooks(ctx context.Context) error { + var errs []error + for _, moduleInstance := range c.moduleInstances { + moduleCtx := &ModuleV1Context{ + Context: ctx, + Module: moduleInstance, + Logger: c.logger.Named(moduleInstance.Module().ID), + } + if err := moduleInstance.Cleanup(moduleCtx); err != nil { + errs = append(errs, newModuleV1Error(moduleInstance.Module().ID, PhaseCleanup, err)) + } + } + + return errors.Join(errs...) +} + +// ModuleV1Error provides structured error information for module operations +type ModuleV1Error struct { + ModuleID string + Phase phase + HookName *string + Err error +} + +type phase string + +const ( + PhaseProvision phase = "provision" + PhaseCleanup phase = "cleanup" + PhaseHook phase = "hook" +) + +func (e *ModuleV1Error) Error() string { + if e.Phase == PhaseHook && e.HookName != nil { + return fmt.Sprintf("module %s %s %s error: %v", e.ModuleID, e.Phase, *e.HookName, e.Err) + } + return fmt.Sprintf("module %s %s error: %v", e.ModuleID, e.Phase, e.Err) +} + +func (e *ModuleV1Error) Unwrap() error { + return e.Err +} + +func newModuleV1Error(moduleID string, phase phase, err error) error { + return &ModuleV1Error{ + ModuleID: moduleID, + Phase: phase, + Err: err, + } +} + +func newModuleV1HookError(moduleID, hookName string, err error) error { + return &ModuleV1Error{ + ModuleID: moduleID, + Phase: PhaseHook, + HookName: &hookName, + Err: err, + } +} diff --git a/router/core/modules_v1_test.go b/router/core/modules_v1_test.go new file mode 100644 index 0000000000..03e451b047 --- /dev/null +++ b/router/core/modules_v1_test.go @@ -0,0 +1,340 @@ +package core + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/utils" + + "go.uber.org/zap/zaptest" +) + +type testModule1 struct{} + +func (m *testModule1) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "testModule1", + Priority: utils.Ptr(0), + New: func() ModuleV1 { + return &testModule1{} + }, + } +} +func (m *testModule1) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule1) Cleanup(ctx *ModuleV1Context) error { return errors.New("test error 1") } + +func (m *testModule1) OnApplicationStart(ctx ApplicationStartHookContext) error { return nil } + +type testModule2 struct{} + +func (m *testModule2) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "testModule2", + Priority: utils.Ptr(1), + New: func() ModuleV1 { + return &testModule2{} + }, + } +} +func (m *testModule2) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule2) Cleanup(ctx *ModuleV1Context) error { return nil } + +func (m *testModule2) OnApplicationStart(ctx ApplicationStartHookContext) error { return nil } +func (m *testModule2) OnApplicationStop(ctx ApplicationStopHookContext) error { return nil } + +type testModule3 struct{} + +func (m *testModule3) Module() ModuleV1Info { + return ModuleV1Info{ + Priority: utils.Ptr(1), + New: func() ModuleV1 { + return &testModule3{} + }, + } +} +func (m *testModule3) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule3) Cleanup(ctx *ModuleV1Context) error { return nil } + +func (m *testModule3) OnApplicationStart(ctx ApplicationStartHookContext) error { return nil } +func (m *testModule3) OnApplicationStop(ctx ApplicationStopHookContext) error { return nil } + +type testModule4 struct{} + +func (m *testModule4) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "testModule4", + Priority: utils.Ptr(1), + } +} +func (m *testModule4) Provision(ctx *ModuleV1Context) error { return nil } +func (m *testModule4) Cleanup(ctx *ModuleV1Context) error { return errors.New("test error 4") } + +// interface guards +var _ ApplicationStartHook = (*testModule1)(nil) + +// registers the applicationStartHook only once +var _ ApplicationStartHook = (*testModule2)(nil) +var _ ApplicationLifecycleHook = (*testModule2)(nil) + +var _ ApplicationStartHook = (*testModule3)(nil) +var _ ApplicationStopHook = (*testModule3)(nil) + +func TestSortModulesV1(t *testing.T) { + t.Parallel() + + module0 := ModuleV1Info{ + ID: "module0", + Priority: utils.Ptr(0), + } + + module1 := ModuleV1Info{ + ID: "module1", + Priority: utils.Ptr(1), + } + + module2 := ModuleV1Info{ + ID: "module2", + Priority: utils.Ptr(2), + } + + module3 := ModuleV1Info{ + ID: "module3", + Priority: utils.Ptr(0), + } + + moduleNilPriority := ModuleV1Info{ + ID: "moduleNil", + } + + t.Run("success", func(t *testing.T) { + modules := []ModuleV1Info{ + moduleNilPriority, + module2, + module0, + module1, + } + result := sortModulesV1(modules) + + expected := []ModuleV1Info{ + module0, + module1, + module2, + moduleNilPriority, + } + + require.EqualValues(t, expected, result) + }) + + t.Run("same priority", func(t *testing.T) { + modules := []ModuleV1Info{ + module3, + module0, + } + result := sortModulesV1(modules) + + expected := []ModuleV1Info{ + module3, + module0, + } + + require.EqualValues(t, expected, result) + }) + + t.Run("no modules not panic", func(t *testing.T) { + modules := []ModuleV1Info{} + require.Equal(t, []ModuleV1Info{}, sortModulesV1(modules)) + }) +} + +func TestInitModulesV1(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + modules := []ModuleV1Info{ + { + ID: "testModule1", + New: func() ModuleV1 { + return &testModule1{} + }, + }, + { + ID: "testModule3", + New: func() ModuleV1 { + return &testModule3{} + }, + }, + } + cm := newCoreModuleHooks(zaptest.NewLogger(t)) + err := cm.initCoreModuleHooks(context.Background(), modules) + require.NoError(t, err) + + require.Equal(t, 2, len(cm.hookRegistry.applicationStartHooks.Values())) + require.Equal(t, 1, len(cm.hookRegistry.applicationStopHooks.Values())) + }) +} + +func TestCleanupModulesV1(t *testing.T) { + t.Parallel() + + t.Run("all modules get a chance to cleanup", func(t *testing.T) { + modules := []ModuleV1Info{ + { + ID: "testModule1", + New: func() ModuleV1 { + return &testModule1{} + }, + }, + { + ID: "testModule2", + New: func() ModuleV1 { + return &testModule2{} + }, + }, + { + ID: "testModule4", + New: func() ModuleV1 { + return &testModule4{} + }, + }, + } + cm := newCoreModuleHooks(zaptest.NewLogger(t)) + err := cm.initCoreModuleHooks(context.Background(), modules) + require.NoError(t, err) + + err = cm.cleanupCoreModuleHooks(context.Background()) + require.Error(t, err) + require.Equal(t, "module testModule1 cleanup error: test error 1\nmodule testModule4 cleanup error: test error 4", err.Error()) + }) +} + +type failingProvisionModule struct{} + +func (m *failingProvisionModule) Module() ModuleV1Info { + return ModuleV1Info{ + ID: "failing-provision-module", + New: func() ModuleV1 { + return &failingProvisionModule{} + }, + } +} + +func (m *failingProvisionModule) Provision(ctx *ModuleV1Context) error { + return errors.New("provision failed") +} + +func (m *failingProvisionModule) Cleanup(ctx *ModuleV1Context) error { + return nil +} + +func TestProvisionErrors(t *testing.T) { + t.Parallel() + + t.Run("provision failure stops initialization", func(t *testing.T) { + modules := []ModuleV1Info{ + { + ID: "failing-provision-module", + New: func() ModuleV1 { + return &failingProvisionModule{} + }, + }, + } + + cm := newCoreModuleHooks(zaptest.NewLogger(t)) + err := cm.initCoreModuleHooks(context.Background(), modules) + + require.Error(t, err) + var moduleErr *ModuleV1Error + require.True(t, errors.As(err, &moduleErr)) + require.Equal(t, "failing-provision-module", moduleErr.ModuleID) + require.Equal(t, PhaseProvision, moduleErr.Phase) + }) +} + +type failingHookModule struct{} + +func (m *failingHookModule) OnApplicationStart(ctx ApplicationStartHookContext) error { + return errors.New("hook execution failed") +} + +func TestHookExecution(t *testing.T) { + t.Parallel() + + t.Run("hook execution with error", func(t *testing.T) { + hooks := []moduleHook[ApplicationStartHook]{ + { + ID: "failing-module", + Hook: &failingHookModule{}, + }, + } + + err := executeHooks(hooks, func(h ApplicationStartHook) error { + return h.OnApplicationStart(nil) + }, "OnApplicationStart", zaptest.NewLogger(t)) + + require.Error(t, err) + var moduleErr *ModuleV1Error + require.True(t, errors.As(err, &moduleErr)) + require.Equal(t, "failing-module", moduleErr.ModuleID) + require.Equal(t, PhaseHook, moduleErr.Phase) + require.Equal(t, "OnApplicationStart", *moduleErr.HookName) + }) + + t.Run("hook execution success", func(t *testing.T) { + hooks := []moduleHook[ApplicationStartHook]{ + { + ID: "test-module", + Hook: &testModule1{}, + }, + } + + err := executeHooks(hooks, func(h ApplicationStartHook) error { + return h.OnApplicationStart(nil) + }, "OnApplicationStart", zaptest.NewLogger(t)) + + require.NoError(t, err) + }) +} + +func TestModuleV1Error(t *testing.T) { + t.Parallel() + + t.Run("provision error", func(t *testing.T) { + err := newModuleV1Error("test-module", PhaseProvision, errors.New("foo")) + moduleErr := err.(*ModuleV1Error) + + require.Equal(t, "test-module", moduleErr.ModuleID) + require.Equal(t, PhaseProvision, moduleErr.Phase) + require.Nil(t, moduleErr.HookName) + require.Equal(t, "module test-module provision error: foo", err.Error()) + }) + + t.Run("cleanup error", func(t *testing.T) { + err := newModuleV1Error("test-module", PhaseCleanup, errors.New("foo")) + moduleErr := err.(*ModuleV1Error) + + require.Equal(t, "test-module", moduleErr.ModuleID) + require.Equal(t, PhaseCleanup, moduleErr.Phase) + require.Nil(t, moduleErr.HookName) + require.Equal(t, "module test-module cleanup error: foo", err.Error()) + }) + + t.Run("hook error", func(t *testing.T) { + err := newModuleV1HookError("test-module", "OnApplicationStart", errors.New("foo")) + moduleErr := err.(*ModuleV1Error) + + require.Equal(t, "test-module", moduleErr.ModuleID) + require.Equal(t, PhaseHook, moduleErr.Phase) + require.NotNil(t, moduleErr.HookName) + require.Equal(t, "OnApplicationStart", *moduleErr.HookName) + require.Equal(t, "module test-module hook OnApplicationStart error: foo", err.Error()) + }) + + t.Run("error unwrapping", func(t *testing.T) { + originalErr := errors.New("original error") + err := newModuleV1Error("test-module", PhaseProvision, originalErr) + + require.Equal(t, originalErr, errors.Unwrap(err)) + }) +} diff --git a/router/core/router.go b/router/core/router.go index f3e3b9c04e..7daf7b7f0e 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -83,6 +83,7 @@ type ( Config httpServer *server modules []Module + moduleHooks *coreModuleHooks EngineStats statistics.EngineStatistics playgroundHandler func(http.Handler) http.Handler proxy ProxyFunc @@ -640,6 +641,22 @@ func (r *Router) initModules(ctx context.Context) error { return nil } +func (r *Router) initModulesV1(ctx context.Context) error { + if len(r.customModulesV1) == 0 { + return nil + } + + moduleList := make([]ModuleV1Info, 0, len(r.customModulesV1)) + for _, module := range r.customModulesV1 { + moduleList = append(moduleList, module.Module()) + } + + moduleList = sortModulesV1(moduleList) + r.moduleHooks = newCoreModuleHooks(r.logger) + + return r.moduleHooks.initCoreModuleHooks(ctx, moduleList) +} + func (r *Router) BaseURL() string { return r.baseURL } @@ -915,6 +932,11 @@ func (r *Router) bootstrap(ctx context.Context) error { return fmt.Errorf("failed to init user modules: %w", err) } + // Initialize ModulesV1 system + if err := r.initModulesV1(ctx); err != nil { + return fmt.Errorf("failed to init modulesV1: %w", err) + } + if r.traceConfig.Enabled && len(r.tracePropagators) > 0 { r.compositePropagator = propagation.NewCompositeTextMapPropagator(r.tracePropagators...) @@ -1469,6 +1491,18 @@ func (r *Router) Shutdown(ctx context.Context) error { } }() + // Cleanup ModulesV1 + if r.moduleHooks != nil { + wg.Add(1) + go func() { + defer wg.Done() + + if subErr := r.moduleHooks.cleanupCoreModuleHooks(ctx); subErr != nil { + err.Append(fmt.Errorf("failed to cleanup ModulesV1: %w", subErr)) + } + }() + } + // Shutdown the CDN operation client and free up resources if r.persistedOperationClient != nil { r.persistedOperationClient.Close() @@ -1722,6 +1756,12 @@ func WithCustomModules(modules ...Module) Option { } } +func WithCustomModulesV1(modules ...ModuleV1) Option { + return func(r *Router) { + r.customModulesV1 = modules + } +} + func WithSubgraphTransportOptions(opts *SubgraphTransportOptions) Option { return func(r *Router) { r.subgraphTransportOptions = opts diff --git a/router/core/router_config.go b/router/core/router_config.go index 4173f45e50..17f4185ec9 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -101,6 +101,7 @@ type Config struct { registrationInfo *nodev1.RegistrationInfo securityConfiguration config.SecurityConfiguration customModules []Module + customModulesV1 []ModuleV1 engineExecutionConfiguration config.EngineExecutionConfiguration // should be removed once the users have migrated to the new overrides config overrideRoutingURLConfiguration config.OverrideRoutingURLConfiguration @@ -202,6 +203,7 @@ func (c *Config) Usage() map[string]any { usage["prometheus"] = c.prometheusServer != nil usage["custom_modules"] = len(c.customModules) > 0 + usage["custom_modules_v1"] = len(c.customModulesV1) > 0 usage["header_rules"] = c.headerRules != nil && (c.headerRules.All != nil || len(c.headerRules.Subgraphs) > 0) usage["subgraph_transport_options"] = c.subgraphTransportOptions != nil usage["graphql_metrics"] = c.graphqlMetricsConfig != nil && c.graphqlMetricsConfig.Enabled diff --git a/router/internal/utils/ordered_set.go b/router/internal/utils/ordered_set.go new file mode 100644 index 0000000000..aab895b0b5 --- /dev/null +++ b/router/internal/utils/ordered_set.go @@ -0,0 +1,61 @@ +package utils + +type OrderedSet[T comparable] struct { + elements []T + index map[T]struct{} +} + +// NewOrderedSet creates and returns a new OrderedSet. +func NewOrderedSet[T comparable]() *OrderedSet[T] { + return &OrderedSet[T]{ + elements: make([]T, 0), + index: make(map[T]struct{}), + } +} + +// Add inserts elem into the set if it's not already present. +func (s *OrderedSet[T]) Add(elem T) { + if _, exists := s.index[elem]; !exists { + s.index[elem] = struct{}{} + s.elements = append(s.elements, elem) + } +} + +// Remove deletes elem from the set if it exists, preserving order of other elements. +func (s *OrderedSet[T]) Remove(elem T) { + if _, exists := s.index[elem]; exists { + delete(s.index, elem) + // rebuild slice without the removed element + for i, v := range s.elements { + if v == elem { + s.elements = append(s.elements[:i], s.elements[i+1:]...) + break + } + } + } +} + +// Contains returns true if elem is in the set. +func (s *OrderedSet[T]) Contains(elem T) bool { + _, exists := s.index[elem] + return exists +} + +// Values returns a slice of elements in insertion order. +// The returned slice is a copy; modifying it won't affect the set. +func (s *OrderedSet[T]) Values() []T { + dup := make([]T, len(s.elements)) + copy(dup, s.elements) + return dup +} + +// Len returns the number of elements in the set. +func (s *OrderedSet[T]) Len() int { + return len(s.elements) +} + +// Clear removes all elements from the set. +func (s *OrderedSet[T]) Clear() { + s.elements = make([]T, 0) + s.index = make(map[T]struct{}) +} \ No newline at end of file diff --git a/router/internal/utils/ptrs.go b/router/internal/utils/ptrs.go new file mode 100644 index 0000000000..c954c0b883 --- /dev/null +++ b/router/internal/utils/ptrs.go @@ -0,0 +1,6 @@ +package utils + +func Ptr[T any](v T) *T { + return &v +} +