Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions router-tests/modules_v1/custom_modules/database_module.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package custom_modules

import (
"fmt"
"sync"
"time"

"github.com/wundergraph/cosmo/router/core"
"go.uber.org/zap"
)

type DatabaseModule struct {
mu sync.RWMutex
connections map[string]*DatabaseConnection
metrics *DatabaseMetrics
isReady bool
wg sync.WaitGroup
}

type DatabaseConnection struct {
ID string
CreatedAt time.Time
LastUsed time.Time
IsActive bool
}

type DatabaseMetrics struct {
TotalConnections int
ActiveQueries int
TotalQueries int64
}

func (m *DatabaseModule) Module() core.ModuleV1Info {
priority := 2
return core.ModuleV1Info{
ID: "database_module",
Priority: &priority,
New: func() core.ModuleV1 {
// For testing purposes, return the same instance so tests can inspect state.
// In production modules, you'd typically return &DatabaseModule{} for isolation.
return m
},
}
}

func (m *DatabaseModule) Provision(ctx *core.ModuleV1Context) error {
ctx.Logger.Info("Initializing database module...")

m.mu.Lock()
defer m.mu.Unlock()

m.connections = make(map[string]*DatabaseConnection)
m.metrics = &DatabaseMetrics{
TotalConnections: 0,
ActiveQueries: 0,
TotalQueries: 0,
}

for i := 0; i < 5; i++ {
connID := fmt.Sprintf("conn_%d", i)
conn := &DatabaseConnection{
ID: connID,
CreatedAt: time.Now(),
LastUsed: time.Now(),
IsActive: true,
}
m.connections[connID] = conn
m.metrics.TotalConnections++
}

m.isReady = true

ctx.Logger.Info("Database module provisioned successfully",
zap.Int("connections", m.metrics.TotalConnections))
return nil
}

func (m *DatabaseModule) Cleanup(ctx *core.ModuleV1Context) error {
ctx.Logger.Info("Shutting down database module...")

m.wg.Wait()

m.mu.Lock()
defer m.mu.Unlock()

for connID, conn := range m.connections {
conn.IsActive = false
ctx.Logger.Info("Closing database connection", zap.String("connection_id", connID))
delete(m.connections, connID)
}

m.metrics.TotalConnections = 0
m.metrics.ActiveQueries = 0
m.isReady = false

ctx.Logger.Info("Database module cleaned up successfully")
return nil
}

func (m *DatabaseModule) GetConnectionCount() int {
m.mu.RLock()
defer m.mu.RUnlock()
if m.connections == nil {
return 0
}
return len(m.connections)
}

func (m *DatabaseModule) SimulateQuery(queryID string) error {
m.mu.Lock()
defer m.mu.Unlock()

if !m.isReady {
return fmt.Errorf("database module not ready")
}

var selectedConn *DatabaseConnection
for _, conn := range m.connections {
if conn.IsActive {
selectedConn = conn
break
}
}

if selectedConn == nil {
return fmt.Errorf("no available connections")
}

selectedConn.LastUsed = time.Now()
m.metrics.ActiveQueries++
m.metrics.TotalQueries++

m.wg.Add(1)
go func() {
defer m.wg.Done()
time.Sleep(10 * time.Millisecond)
m.mu.Lock()
m.metrics.ActiveQueries--
m.mu.Unlock()
}()

return nil
}

func (m *DatabaseModule) GetMetrics() DatabaseMetrics {
m.mu.RLock()
defer m.mu.RUnlock()
if m.metrics == nil {
return DatabaseMetrics{}
}
return *m.metrics
}
99 changes: 99 additions & 0 deletions router-tests/modules_v1/module_lifecycle_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package modules_v1

import (
"context"
"encoding/json"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router-tests/modules_v1/custom_modules"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
"go.uber.org/zap/zaptest"
)

func TestModuleV1ProvisionAndCleanupLifecycle(t *testing.T) {
t.Parallel()

t.Run("database module lifecycle", func(t *testing.T) {
t.Parallel()

dbModule := &custom_modules.DatabaseModule{}
logger := zaptest.NewLogger(t)

moduleCtx := &core.ModuleV1Context{
Context: context.Background(),
Module: dbModule,
Logger: logger,
}

err := dbModule.Provision(moduleCtx)
require.NoError(t, err)

assert.Equal(t, 5, dbModule.GetConnectionCount())
metrics := dbModule.GetMetrics()
assert.Equal(t, 5, metrics.TotalConnections)

err = dbModule.SimulateQuery("manual_test_query")
require.NoError(t, err)

updatedMetrics := dbModule.GetMetrics()
assert.Equal(t, int64(1), updatedMetrics.TotalQueries)

err = dbModule.Cleanup(moduleCtx)
require.NoError(t, err)

assert.Equal(t, 0, dbModule.GetConnectionCount())
finalMetrics := dbModule.GetMetrics()
assert.Equal(t, 0, finalMetrics.TotalConnections)
assert.Equal(t, 0, finalMetrics.ActiveQueries)

err = dbModule.SimulateQuery("post_cleanup_query")
assert.Error(t, err)
assert.Contains(t, err.Error(), "database module not ready")
})

t.Run("no regression with the module system introduced", func(t *testing.T) {
t.Parallel()

dbModule := &custom_modules.DatabaseModule{}
assert.Equal(t, 0, dbModule.GetConnectionCount())

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithCustomModulesV1(dbModule),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
assert.Equal(t, 5, dbModule.GetConnectionCount(), "should have 5 connections after provision")

metrics := dbModule.GetMetrics()
assert.Equal(t, 5, metrics.TotalConnections)
assert.Equal(t, 0, metrics.ActiveQueries)
assert.Equal(t, int64(0), metrics.TotalQueries)

err := dbModule.SimulateQuery("test_query_1")
require.NoError(t, err)

err = dbModule.SimulateQuery("test_query_2")
require.NoError(t, err)

updatedMetrics := dbModule.GetMetrics()
assert.Equal(t, int64(2), updatedMetrics.TotalQueries, "should have recorded 2 queries")

require.Eventually(t, func() bool {
return dbModule.GetMetrics().ActiveQueries == 0
}, 1*time.Second, 1*time.Millisecond, "all queries should be completed")

res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `query MyQuery { employees { id } }`,
OperationName: json.RawMessage(`"MyQuery"`),
})
require.NoError(t, err)
assert.Equal(t, 200, res.Response.StatusCode)
assert.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body)
})

})
}
Loading
Loading