diff --git a/router-tests/modules/verify-operation-context-values/module.go b/router-tests/modules/verify-operation-context-values/module.go new file mode 100644 index 0000000000..b044bc71c5 --- /dev/null +++ b/router-tests/modules/verify-operation-context-values/module.go @@ -0,0 +1,84 @@ +package verify_operation_context_values + +import ( + "net/http" + + "github.com/wundergraph/astjson" + "github.com/wundergraph/cosmo/router/core" + "go.uber.org/zap" +) + +const myModuleID = "verifyOperationContextValues" + +// CapturedOperationValues holds the captured values from operation context +type CapturedOperationValues struct { + Name string + Type string + Hash uint64 + Content string + Variables *astjson.Value + // Store the raw variables as string for easier testing + VariablesJSON string + ClientInfo core.ClientInfo +} + +// VerifyOperationContextValuesModule captures operation context values for verification +type VerifyOperationContextValuesModule struct { + ResultsChan chan CapturedOperationValues + Logger *zap.Logger +} + +func (m *VerifyOperationContextValuesModule) Provision(ctx *core.ModuleContext) error { + m.Logger = ctx.Logger + if m.ResultsChan == nil { + m.ResultsChan = make(chan CapturedOperationValues, 1) + } + return nil +} + +func (m *VerifyOperationContextValuesModule) Middleware(ctx core.RequestContext, next http.Handler) { + operation := ctx.Operation() + + // Capture all the operation context values + captured := CapturedOperationValues{ + Name: operation.Name(), + Type: operation.Type(), + Hash: operation.Hash(), + Content: operation.Content(), + Variables: operation.Variables(), + ClientInfo: operation.ClientInfo(), + } + + // Convert variables to JSON string for easier testing + captured.VariablesJSON = "{}" + if captured.Variables != nil { + variablesBytes := captured.Variables.MarshalTo(nil) + captured.VariablesJSON = string(variablesBytes) + } + // Send the captured values to the test + select { + case m.ResultsChan <- captured: + default: + // Channel is full, skip + } + + // Call the next handler in the chain + next.ServeHTTP(ctx.ResponseWriter(), ctx.Request()) +} + +func (m *VerifyOperationContextValuesModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &VerifyOperationContextValuesModule{ + ResultsChan: make(chan CapturedOperationValues, 1), + } + }, + } +} + +// Interface guard +var ( + _ core.RouterMiddlewareHandler = (*VerifyOperationContextValuesModule)(nil) +) diff --git a/router-tests/modules/verify_operation_context_values_test.go b/router-tests/modules/verify_operation_context_values_test.go new file mode 100644 index 0000000000..a7ebd20be1 --- /dev/null +++ b/router-tests/modules/verify_operation_context_values_test.go @@ -0,0 +1,195 @@ +package module_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + verifyModule "github.com/wundergraph/cosmo/router-tests/modules/verify-operation-context-values" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "go.uber.org/zap/zapcore" +) + +func TestVerifyOperationContextValues(t *testing.T) { + t.Run("verifies all operation context values are set correctly", func(t *testing.T) { + t.Parallel() + + resultsChan := make(chan verifyModule.CapturedOperationValues, 1) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]any{ + "verifyOperationContextValues": verifyModule.VerifyOperationContextValuesModule{ + ResultsChan: resultsChan, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&verifyModule.VerifyOperationContextValuesModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Send a GraphQL query with variables that are actually used + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query GetEmployee($empId: Int!) { + employee(id: $empId) { + id + details { + forename + surname + } + tag + } + }`, + Variables: json.RawMessage(`{"empId": 1}`), + OperationName: json.RawMessage(`"GetEmployee"`), + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + + // Wait for the module to capture the operation context values + testenv.AwaitChannelWithT(t, 10*time.Second, resultsChan, func(t *testing.T, captured verifyModule.CapturedOperationValues) { + // Verify operation name + assert.Equal(t, "GetEmployee", captured.Name, "Operation name should be set correctly") + + // Verify operation type + assert.Equal(t, "query", captured.Type, "Operation type should be 'query'") + + // Verify operation hash is set (non-zero) + assert.NotZero(t, captured.Hash, "Operation hash should be set") + + // Verify operation content is set and contains the normalized query + assert.Equal(t, captured.Content, "query GetEmployee($a: Int!){employee(id: $a){id details {forename surname} tag}}", "Operation content should be set") + + // Verify Variables() method returns the correct variables + assert.NotNil(t, captured.Variables, "Variables should not be nil") + + // Verify the variables JSON contains the expected values + assert.JSONEq(t, `{"empId": 1}`, captured.VariablesJSON, "Variables JSON should match the sent variables") + + // Verify we can access individual variables + empIdVar := captured.Variables.Get("empId") + assert.NotNil(t, empIdVar, "Should be able to access 'empId' variable") + assert.Equal(t, 1, empIdVar.GetInt(), "empId variable should be 1") + + // Verify client info is populated (at least the basic structure) + assert.NotNil(t, captured.ClientInfo, "Client info should be set") + }) + }) + }) + + t.Run("verifies context values with empty variables", func(t *testing.T) { + t.Parallel() + + resultsChan := make(chan verifyModule.CapturedOperationValues, 1) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]any{ + "verifyOperationContextValues": verifyModule.VerifyOperationContextValuesModule{ + ResultsChan: resultsChan, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&verifyModule.VerifyOperationContextValuesModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Send a simple query without variables + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query SimpleQuery { employees { id } }`, + OperationName: json.RawMessage(`"SimpleQuery"`), + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + + // Wait for the module to capture the operation context values + testenv.AwaitChannelWithT(t, 10*time.Second, resultsChan, func(t *testing.T, captured verifyModule.CapturedOperationValues) { + // Verify operation name + assert.Equal(t, "SimpleQuery", captured.Name, "Operation name should be set correctly") + + // Verify operation type + assert.Equal(t, "query", captured.Type, "Operation type should be 'query'") + + // Verify Variables() method works with empty variables + assert.NotNil(t, captured.Variables, "Variables should not be nil even when empty") + + // Verify the variables JSON is an empty object + assert.JSONEq(t, `{}`, captured.VariablesJSON, "Variables JSON should be empty object when no variables provided") + }) + }) + }) + + t.Run("verifies context values with mutation", func(t *testing.T) { + t.Parallel() + + resultsChan := make(chan verifyModule.CapturedOperationValues, 1) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]any{ + "verifyOperationContextValues": verifyModule.VerifyOperationContextValuesModule{ + ResultsChan: resultsChan, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&verifyModule.VerifyOperationContextValuesModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Send a mutation with variables + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `mutation UpdateEmployee($empId: Int!, $newTag: String!) { + updateEmployeeTag(id: $empId, tag: $newTag) { + id + tag + } + }`, + Variables: json.RawMessage(`{"empId": 1, "newTag": "Updated by test"}`), + OperationName: json.RawMessage(`"UpdateEmployee"`), + }) + require.NoError(t, err) + assert.Equal(t, 200, res.Response.StatusCode) + + // Wait for the module to capture the operation context values + testenv.AwaitChannelWithT(t, 10*time.Second, resultsChan, func(t *testing.T, captured verifyModule.CapturedOperationValues) { + // Verify operation name + assert.Equal(t, "UpdateEmployee", captured.Name, "Operation name should be set correctly") + + // Verify operation type is mutation + assert.Equal(t, "mutation", captured.Type, "Operation type should be 'mutation'") + + // Verify Variables() method returns the correct mutation variables + assert.NotNil(t, captured.Variables, "Variables should not be nil") + + // Verify the variables contain the mutation input + empIdVar := captured.Variables.Get("empId") + assert.NotNil(t, empIdVar, "Should be able to access 'empId' variable") + assert.Equal(t, 1, empIdVar.GetInt(), "empId variable should be 1") + + newTagVar := captured.Variables.Get("newTag") + assert.NotNil(t, newTagVar, "Should be able to access 'newTag' variable") + + // Verify string variable + newTagBytes := newTagVar.GetStringBytes() + assert.Equal(t, "Updated by test", string(newTagBytes), "newTag should be 'Updated by test'") + }) + }) + }) +} diff --git a/router/core/context.go b/router/core/context.go index 1cafc24ecd..fee849f778 100644 --- a/router/core/context.go +++ b/router/core/context.go @@ -3,7 +3,6 @@ package core import ( "context" "errors" - rcontext "github.com/wundergraph/cosmo/router/internal/context" "net/http" "net/url" "strconv" @@ -11,6 +10,8 @@ import ( "sync" "time" + rcontext "github.com/wundergraph/cosmo/router/internal/context" + "go.opentelemetry.io/otel/attribute" "go.uber.org/zap" @@ -469,6 +470,8 @@ type OperationContext interface { Hash() uint64 // Content is the content of the operation Content() string + // Variables is the variables of the operation + Variables() *astjson.Value // ClientInfo returns information about the client that initiated this operation ClientInfo() ClientInfo // QueryPlanStats returns some statistics about the query plan for the operation