diff --git a/router-tests/modules/sha256-verifier/module.go b/router-tests/modules/sha256-verifier/module.go new file mode 100644 index 0000000000..8623a78b13 --- /dev/null +++ b/router-tests/modules/sha256-verifier/module.go @@ -0,0 +1,50 @@ +package sha256_verifier + +import ( + "net/http" + + "github.com/wundergraph/cosmo/router/core" +) + +const myModuleID = "sha256VerifierModule" + +// ResultContainer holds the SHA256 result, shared across module instances +type ResultContainer struct { + Sha256Result string +} + +// Sha256VerifierModule is a simple module that has access to the GraphQL operation and adds custom scopes to the response +type Sha256VerifierModule struct { + ForceSha256 bool + ResultContainer *ResultContainer +} + +func (m *Sha256VerifierModule) Middleware(ctx core.RequestContext, next http.Handler) { + m.ResultContainer.Sha256Result = ctx.Operation().Sha256Hash() + next.ServeHTTP(ctx.ResponseWriter(), ctx.Request()) +} + +func (m *Sha256VerifierModule) RouterOnRequest(ctx core.RequestContext, next http.Handler) { + if m.ForceSha256 { + ctx.SetForceSha256Compute() + } + next.ServeHTTP(ctx.ResponseWriter(), ctx.Request()) +} + +func (m *Sha256VerifierModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &Sha256VerifierModule{} + }, + } +} + +// Interface guard +var ( + _ core.RouterMiddlewareHandler = (*Sha256VerifierModule)(nil) + _ core.RouterOnRequestHandler = (*Sha256VerifierModule)(nil) +) diff --git a/router-tests/modules/sha256_verifier_test.go b/router-tests/modules/sha256_verifier_test.go new file mode 100644 index 0000000000..93a136ba94 --- /dev/null +++ b/router-tests/modules/sha256_verifier_test.go @@ -0,0 +1,164 @@ +package module_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + sha256_verifier "github.com/wundergraph/cosmo/router-tests/modules/sha256-verifier" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestSha256VerifierModule(t *testing.T) { + t.Parallel() + + t.Run("verify Sha256Hash is not captured when sha256 force is not enabled", func(t *testing.T) { + t.Parallel() + + resultContainer := &sha256_verifier.ResultContainer{} + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "sha256VerifierModule": sha256_verifier.Sha256VerifierModule{ + ForceSha256: false, + ResultContainer: resultContainer, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query MyQuery { employees { id } }`, + OperationName: json.RawMessage(`"MyQuery"`), + }) + require.NoError(t, err) + require.Equal(t, 200, res.Response.StatusCode) + + require.Empty(t, resultContainer.Sha256Result) + }) + }) + + t.Run("verify sha256Hash is captured from operation when force is enabled", func(t *testing.T) { + t.Parallel() + + resultContainer := &sha256_verifier.ResultContainer{} + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "sha256VerifierModule": sha256_verifier.Sha256VerifierModule{ + ForceSha256: true, + ResultContainer: resultContainer, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query MyQuery { employees { id } }`, + OperationName: json.RawMessage(`"MyQuery"`), + }) + require.NoError(t, err) + require.Equal(t, 200, res.Response.StatusCode) + + require.NotEmpty(t, resultContainer.Sha256Result) + require.Equal(t, "f037469b9c85bb28ae4c13e1d51c1f7e3333ecbe3c28b877c8659a52378f56c0", resultContainer.Sha256Result) + }) + }) + + t.Run("verify different queries produces different Sha256Hashes", func(t *testing.T) { + t.Parallel() + + resultContainer := &sha256_verifier.ResultContainer{} + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "sha256VerifierModule": sha256_verifier.Sha256VerifierModule{ + ForceSha256: true, + ResultContainer: resultContainer, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + _, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query ConsistentQuery { employees { id } }`, + OperationName: json.RawMessage(`"ConsistentQuery"`), + }) + require.NoError(t, err) + firstHash := resultContainer.Sha256Result + require.NotEmpty(t, firstHash) + + _, err = xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query ConsistentQuery { employees { id tag } }`, + OperationName: json.RawMessage(`"ConsistentQuery"`), + }) + require.NoError(t, err) + secondHash := resultContainer.Sha256Result + require.NotEmpty(t, secondHash) + + require.NotEqual(t, firstHash, secondHash) + }) + }) + + t.Run("verify the same query produces same Sha256Hash", func(t *testing.T) { + t.Parallel() + + resultContainer := &sha256_verifier.ResultContainer{} + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "sha256VerifierModule": sha256_verifier.Sha256VerifierModule{ + ForceSha256: true, + ResultContainer: resultContainer, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&sha256_verifier.Sha256VerifierModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + _, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query ConsistentQuery { employees { id } }`, + OperationName: json.RawMessage(`"ConsistentQuery"`), + }) + require.NoError(t, err) + firstHash := resultContainer.Sha256Result + require.NotEmpty(t, firstHash) + + _, err = xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query ConsistentQuery { employees { id } }`, + OperationName: json.RawMessage(`"ConsistentQuery"`), + }) + require.NoError(t, err) + secondHash := resultContainer.Sha256Result + require.NotEmpty(t, secondHash) + + require.Equal(t, firstHash, secondHash, "Same query should produce the same SHA256 hash") + }) + }) + +} diff --git a/router/core/context.go b/router/core/context.go index 235759b6ac..4cbde8706a 100644 --- a/router/core/context.go +++ b/router/core/context.go @@ -131,9 +131,14 @@ type RequestContext interface { // SetAuthenticationScopes sets the scopes for the request on Authentication // If Authentication is not set, it will be initialized with the scopes SetAuthenticationScopes(scopes []string) + // SetCustomFieldValueRenderer overrides the default field value rendering behavior // This can be used, e.g. to obfuscate sensitive data in the response SetCustomFieldValueRenderer(renderer resolve.FieldValueRenderer) + + // SetForceSha256Compute forces the computation of the Sha256Hash of the operation + // This is useful if the Sha256Hash is needed in custom modules but not used anywhere else + SetForceSha256Compute() } var metricAttrsPool = sync.Pool{ @@ -263,6 +268,8 @@ type requestContext struct { expressionContext expr.Context // customFieldValueRenderer is used to override the default field value rendering behavior customFieldValueRenderer resolve.FieldValueRenderer + // forceSha256Compute indicates whether the Sha256Hash of the operation should definitely be computed + forceSha256Compute bool } func (c *requestContext) SetCustomFieldValueRenderer(renderer resolve.FieldValueRenderer) { @@ -462,6 +469,10 @@ func (c *requestContext) SetAuthenticationScopes(scopes []string) { auth.SetScopes(scopes) } +func (c *requestContext) SetForceSha256Compute() { + c.forceSha256Compute = true +} + type OperationContext interface { // Name is the name of the operation Name() string @@ -475,6 +486,12 @@ type OperationContext interface { Variables() *astjson.Value // ClientInfo returns information about the client that initiated this operation ClientInfo() ClientInfo + + // Sha256Hash returns the SHA256 hash of the original operation + // It is important to note that this hash is not calculated just because this method has been called + // and is only calculated based on other existing logic (such as if sha256Hash is used in expressions) + Sha256Hash() string + // QueryPlanStats returns some statistics about the query plan for the operation // if called too early in request chain, it may be inaccurate for modules, using // in Middleware is recommended @@ -576,6 +593,10 @@ func (o *operationContext) ClientInfo() ClientInfo { return *o.clientInfo } +func (o *operationContext) Sha256Hash() string { + return o.sha256Hash +} + type QueryPlanStats struct { TotalSubgraphFetches int SubgraphFetches map[string]int diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index cd0b9020a9..bbc3473034 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -439,9 +439,9 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { }) } -func (h *PreHandler) shouldComputeOperationSha256(operationKit *OperationKit) bool { +func (h *PreHandler) shouldComputeOperationSha256(operationKit *OperationKit, reqCtx *requestContext) bool { // If forced, always compute the hash - if h.computeOperationSha256 { + if h.computeOperationSha256 || reqCtx.forceSha256Compute { return true } @@ -523,7 +523,7 @@ func (h *PreHandler) handleOperation(w http.ResponseWriter, req *http.Request, v } // Compute the operation sha256 hash as soon as possible for observability reasons - if h.shouldComputeOperationSha256(operationKit) { + if h.shouldComputeOperationSha256(operationKit, requestContext) { if err := operationKit.ComputeOperationSha256(); err != nil { return &httpGraphqlError{ message: fmt.Sprintf("error hashing operation: %s", err),