diff --git a/router-tests/block_operations_test.go b/router-tests/block_operations_test.go index e7372fb0af..071ecd7a66 100644 --- a/router-tests/block_operations_test.go +++ b/router-tests/block_operations_test.go @@ -123,6 +123,29 @@ func TestBlockOperations(t *testing.T) { }) }) + t.Run("should block operations by parsing time expression", func(t *testing.T) { + t.Parallel() + + // This will verify that parsing time is used to block the operation + expression := "request.operation.parsingTime.Nanoseconds() > 0" + + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockMutations = config.BlockOperationConfiguration{ + Enabled: true, + Condition: expression, + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.JSONEq(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, res.Body) + }) + }) + t.Run("should block operation by scope expression condition", func(t *testing.T) { t.Parallel() diff --git a/router-tests/structured_logging_test.go b/router-tests/structured_logging_test.go index a99b558e73..2f9f03d96b 100644 --- a/router-tests/structured_logging_test.go +++ b/router-tests/structured_logging_test.go @@ -1996,6 +1996,235 @@ func TestFlakyAccessLogs(t *testing.T) { ) }) + t.Run("validate request.operation.sha256Hash expression with persisted hash and body", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, + &testenv.Config{ + AccessLogFields: []config.CustomAttribute{ + { + Key: "operation_sha256_expression", + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.sha256Hash", + }, + }, + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, + func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + OperationName: []byte(`"Employees"`), + Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`), + Header: map[string][]string{"graphql-client-name": {"my-client"}}, + }) + require.NoError(t, err) + require.JSONEq(t, employeesIDData, res.Body) + + requestLog := xEnv.Observer().FilterMessage("/graphql") + requestLogAll := requestLog.All() + requestContext := requestLogAll[0].ContextMap() + + val, ok := requestContext["operation_sha256_expression"].(string) + require.True(t, ok) + require.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", val) + }, + ) + }) + + t.Run("validate request.operation.sha256Hash expression without persisted operation", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, + &testenv.Config{ + AccessLogFields: []config.CustomAttribute{ + { + Key: "operation_sha256_expression", + Default: "not-set", + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.sha256Hash", + }, + }, + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, + func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query employees { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + + requestLog := xEnv.Observer().FilterMessage("/graphql") + requestLogAll := requestLog.All() + requestContext := requestLogAll[0].ContextMap() + + val, ok := requestContext["operation_sha256_expression"].(string) + require.True(t, ok) + require.Equal(t, "c13e0fafb0a3a72e74c19df743fedee690fe133554a17a9408747585a0d1b423", val) + }, + ) + }) + + t.Run("validate request.operation.persistedId expression set with persisted hash", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, + &testenv.Config{ + AccessLogFields: []config.CustomAttribute{ + { + Key: "persisted_id_expression", + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.persistedId", + }, + }, + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, + func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + OperationName: []byte(`"Employees"`), + Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`), + Header: map[string][]string{"graphql-client-name": {"my-client"}}, + }) + require.NoError(t, err) + require.JSONEq(t, employeesIDData, res.Body) + + requestLog := xEnv.Observer().FilterMessage("/graphql") + requestLogAll := requestLog.All() + requestContext := requestLogAll[0].ContextMap() + + val, ok := requestContext["persisted_id_expression"].(string) + require.True(t, ok) + require.Equal(t, "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f", val) + }, + ) + + }) + + t.Run("validate request.operation.parsingTime expression > 0", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, + &testenv.Config{ + AccessLogFields: []config.CustomAttribute{ + { + Key: "parsing_time_expression", + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.parsingTime", + }, + }, + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query employees { employees { id } }`, + }) + requestLog := xEnv.Observer().FilterMessage("/graphql") + requestContext := requestLog.All()[0].ContextMap() + val, ok := requestContext["parsing_time_expression"].(time.Duration) + require.True(t, ok) + require.Greater(t, int(val), 0) + }) + }) + + t.Run("validate request.operation.normalizationTime expression > 0", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, + &testenv.Config{ + AccessLogFields: []config.CustomAttribute{ + { + Key: "normalization_time_expression", + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.normalizationTime", + }, + }, + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query employees { employees { id } }`, + }) + requestLog := xEnv.Observer().FilterMessage("/graphql") + requestContext := requestLog.All()[0].ContextMap() + val, ok := requestContext["normalization_time_expression"].(time.Duration) + require.True(t, ok) + require.Greater(t, int(val), 0) + }) + }) + + t.Run("validate request.operation.validationTime expression > 0", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, + &testenv.Config{ + AccessLogFields: []config.CustomAttribute{ + { + Key: "validation_time_expression", + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.validationTime", + }, + }, + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query employees { employees { id } }`, + }) + requestLog := xEnv.Observer().FilterMessage("/graphql") + requestContext := requestLog.All()[0].ContextMap() + val, ok := requestContext["validation_time_expression"].(time.Duration) + require.True(t, ok) + require.Greater(t, int(val), 0) + }) + }) + + t.Run("validate request.operation.planningTime expression > 0", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, + &testenv.Config{ + AccessLogFields: []config.CustomAttribute{ + { + Key: "planning_time_expression", + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.planningTime", + }, + }, + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query employees { employees { id } }`, + }) + requestLog := xEnv.Observer().FilterMessage("/graphql") + requestContext := requestLog.All()[0].ContextMap() + val, ok := requestContext["planning_time_expression"].(time.Duration) + require.True(t, ok) + require.Greater(t, int(val), 0) + }) + }) + t.Run("should be able to use an expression for access logging in feature flags", func(t *testing.T) { t.Parallel() diff --git a/router-tests/telemetry/telemetry_test.go b/router-tests/telemetry/telemetry_test.go index 93a8e751a5..267bbe7d5f 100644 --- a/router-tests/telemetry/telemetry_test.go +++ b/router-tests/telemetry/telemetry_test.go @@ -42,6 +42,11 @@ const ( defaultCosmoRouterMetricsCount = 7 ) +type spanEntry struct { + name string + spanKind trace.SpanKind +} + func TestFlakyEngineStatisticsTelemetry(t *testing.T) { t.Parallel() @@ -9604,7 +9609,537 @@ func TestFlakyTelemetry(t *testing.T) { }) }) + t.Run("verify request.operation expression attributes with dynamic evaluation", func(t *testing.T) { + t.Parallel() + + t.Run("verify sha256Hash expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.sha256Hash", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + + expectedSha256Hash := "da7b196c305087a40625b93073c796f9182e5693ac764fb72050c24f8c6a6071" + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + require.Equal(t, expectedSha256Hash, value.AsString()) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify parsingTime expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "string(request.operation.parsingTime.Nanoseconds())", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: queryHeader, + spanKind: trace.SpanKindServer, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + intVal, err := strconv.Atoi(value.AsString()) + require.NoError(t, err) + require.Greater(t, intVal, 0) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify name expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.name", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + require.Equal(t, queryName, value.AsString()) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify type expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.type", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + require.Equal(t, queryType, value.AsString()) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify persistedId expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.persistedId", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + OperationName: []byte(`"Employees"`), + Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`), + Header: map[string][]string{"graphql-client-name": {"my-client"}}, + }) + + persistedID := "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f" + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + { + name: "Load Persisted Operation", + spanKind: trace.SpanKindClient, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 10) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + require.Equal(t, persistedID, value.AsString()) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify normalizationTime expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "string(request.operation.normalizationTime.Nanoseconds())", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + { + name: queryHeader, + spanKind: trace.SpanKindServer, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + intVal, err := strconv.Atoi(value.AsString()) + require.NoError(t, err) + require.Greater(t, intVal, 0) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify hash expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "request.operation.hash", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + hash := "1163600561566987607" + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + { + name: queryHeader, + spanKind: trace.SpanKindServer, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + require.Equal(t, hash, value.AsString()) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify validationTime expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "string(request.operation.validationTime.Nanoseconds())", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Normalize", + spanKind: trace.SpanKindInternal, + }, + { + name: queryHeader, + spanKind: trace.SpanKindServer, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + intVal, err := strconv.Atoi(value.AsString()) + require.NoError(t, err) + require.Greater(t, intVal, 0) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify planningTime expression attribute", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: "string(request.operation.planningTime.Nanoseconds())", + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Normalize", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Validate", + spanKind: trace.SpanKindInternal, + }, + { + name: queryHeader, + spanKind: trace.SpanKindServer, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + intVal, err := strconv.Atoi(value.AsString()) + require.NoError(t, err) + require.Greater(t, intVal, 0) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + + t.Run("verify name and hash expression attributes together", func(t *testing.T) { + t.Parallel() + + exporter := tracetest.NewInMemoryExporter(t) + + key := "custom.attribute" + + testenv.Run(t, &testenv.Config{ + TraceExporter: exporter, + CustomTracingAttributes: []config.CustomAttribute{ + { + Key: key, + ValueFrom: &config.CustomDynamicAttribute{ + Expression: `request.operation.hash + " " + request.operation.name`, + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + queryType := "query" + queryName := "exampleName" + queryHeader := queryType + " " + queryName + xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: queryHeader + ` { employees { id } }`, + }) + + hashAndOperationName := "1163600561566987607" + " " + queryName + + skipSpans := []spanEntry{ + { + name: "HTTP - Read Body", + spanKind: trace.SpanKindInternal, + }, + { + name: "Operation - Parse", + spanKind: trace.SpanKindInternal, + }, + { + name: queryHeader, + spanKind: trace.SpanKindServer, + }, + } + + spans := exporter.GetSpans().Snapshots() + require.Len(t, spans, 9) + + detectedSpanCount := validateDetectedSpans(t, spans, key, skipSpans, func(value attribute.Value) { + require.Equal(t, hashAndOperationName, value.AsString()) + }) + + expected := len(spans) - len(skipSpans) + require.Equal(t, expected, detectedSpanCount) + }) + }) + }) + t.Run("verify attribute expressions with subgraph in the expression", func(t *testing.T) { + t.Parallel() + t.Run("verify subgraph expression should only be present for engine fetch", func(t *testing.T) { t.Parallel() @@ -10561,3 +11096,34 @@ func TestOperationBodyAttributes(t *testing.T) { }) }) } + +func validateDetectedSpans(t *testing.T, sn []sdktrace.ReadOnlySpan, key string, skipSpans []spanEntry, validateFunc func(value attribute.Value)) int { + var detectedSpanCount int + + for _, snapshot := range sn { + attributes := snapshot.Attributes() + snapshot.SpanKind() + + spanSearchEntry := spanEntry{name: snapshot.Name(), spanKind: snapshot.SpanKind()} + value, ok := getAttributeFromKey(attributes, key) + if slices.Contains(skipSpans, spanSearchEntry) { + require.False(t, ok) + continue + } + + require.True(t, ok) + + validateFunc(*value) + detectedSpanCount++ + } + return detectedSpanCount +} + +func getAttributeFromKey(attrs []attribute.KeyValue, key string) (*attribute.Value, bool) { + for _, attr := range attrs { + if string(attr.Key) == key { + return &attr.Value, true + } + } + return nil, false +} diff --git a/router/core/attribute_expressions.go b/router/core/attribute_expressions.go index 11e7c430a9..c63de0c2ea 100644 --- a/router/core/attribute_expressions.go +++ b/router/core/attribute_expressions.go @@ -1,107 +1,125 @@ package core import ( + "context" "fmt" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" "reflect" - "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/vm" "github.com/wundergraph/cosmo/router/internal/expr" "github.com/wundergraph/cosmo/router/pkg/config" "go.opentelemetry.io/otel/attribute" ) -// attributeExpressions maps context attributes to custom attributes. -type attributeExpressions struct { - // expressions is a map of expressions that can be used to resolve dynamic attributes - expressions map[string]*vm.Program - // expressionsWithAuth is a map of expressions that can be used to resolve dynamic attributes and acces the auth - // argument - expressionsWithAuth map[string]*vm.Program - - expressionsWithSubgraph map[string]*vm.Program -} - -type VisitorCheckForRequestAuthAccess struct { - HasAuth bool +type ProgramWithKey struct { + Program *vm.Program + Key string } -func (v *VisitorCheckForRequestAuthAccess) Visit(node *ast.Node) { - if node == nil { - return - } - - if v.HasAuth { - return - } - - switch n := (*node).(type) { - case *ast.MemberNode: - property, propertyOk := n.Property.(*ast.StringNode) - node, nodeOk := n.Node.(*ast.IdentifierNode) - if propertyOk && nodeOk { - if node.Value == expr.ExprRequestKey && property.Value == expr.ExprRequestAuthKey { - v.HasAuth = true - } - } - } +// attributeExpressions maps context attributes to custom attributes. +type attributeExpressions struct { + expressions map[expr.AttributeBucket][]ProgramWithKey } func newAttributeExpressions(attr []config.CustomAttribute, exprManager *expr.Manager) (*attributeExpressions, error) { - attrExprMap := make(map[string]*vm.Program) - attrExprMapWithAuth := make(map[string]*vm.Program) - attrExprMapSubgraph := make(map[string]*vm.Program) + attrs := make(map[expr.AttributeBucket][]ProgramWithKey) for _, a := range attr { if a.ValueFrom != nil && a.ValueFrom.Expression != "" { - usesAuth := VisitorCheckForRequestAuthAccess{} - usesSubgraph := expr.UsesSubgraph{} - prog, err := exprManager.CompileExpression(a.ValueFrom.Expression, reflect.String, &usesAuth, &usesSubgraph) + bucket := expr.RequestOperationBucketVisitor{} + + prog, err := exprManager.CompileExpression(a.ValueFrom.Expression, reflect.String, &bucket) if err != nil { return nil, fmt.Errorf("custom attribute error, unable to compile '%s' with expression '%s': %s", a.Key, a.ValueFrom.Expression, err) } - if usesSubgraph.UsesSubgraph { - attrExprMapSubgraph[a.Key] = prog - } else if usesAuth.HasAuth { - attrExprMapWithAuth[a.Key] = prog - } else { - attrExprMap[a.Key] = prog - } + + attrs[bucket.Bucket] = append(attrs[bucket.Bucket], ProgramWithKey{ + Program: prog, + Key: a.Key, + }) } } return &attributeExpressions{ - expressions: attrExprMap, - expressionsWithAuth: attrExprMapWithAuth, - expressionsWithSubgraph: attrExprMapSubgraph, + expressions: attrs, }, nil } -func expressionAttributes(expressions map[string]*vm.Program, exprCtx *expr.Context) ([]attribute.KeyValue, error) { +func (r *attributeExpressions) expressionsAttributes(exprCtx *expr.Context, key expr.AttributeBucket) ([]attribute.KeyValue, error) { if exprCtx == nil { return nil, nil } + programWrappers, ok := r.expressions[key] + if !ok { + return nil, nil + } + var result []attribute.KeyValue - for exprKey, exprVal := range expressions { - val, err := expr.ResolveStringExpression(exprVal, *exprCtx) + for _, wrapper := range programWrappers { + val, err := expr.ResolveStringExpression(wrapper.Program, *exprCtx) if err != nil { return nil, err } - result = append(result, attribute.String(exprKey, val)) + result = append(result, attribute.String(wrapper.Key, val)) } return result, nil } -func (r *attributeExpressions) expressionsAttributes(exprCtx *expr.Context) ([]attribute.KeyValue, error) { - return expressionAttributes(r.expressions, exprCtx) +type AddExprOpts struct { + logger *zap.Logger + expressions *attributeExpressions + key expr.AttributeBucket + currSpan trace.Span + exprCtx *expr.Context + attrAddFunc func(vals ...attribute.KeyValue) } -func (r *attributeExpressions) expressionsAttributesWithAuth(exprCtx *expr.Context) ([]attribute.KeyValue, error) { - return expressionAttributes(r.expressionsWithAuth, exprCtx) +func setTelemetryAttributes(ctx context.Context, requestContext *requestContext, key expr.AttributeBucket) { + currSpan := trace.SpanFromContext(ctx) + addExpressions(AddExprOpts{ + logger: requestContext.logger, + expressions: requestContext.telemetry.telemetryAttributeExpressions, + key: key, + currSpan: currSpan, + exprCtx: &requestContext.expressionContext, + attrAddFunc: requestContext.telemetry.addCommonAttribute, + }) + + addExpressions(AddExprOpts{ + logger: requestContext.logger, + expressions: requestContext.telemetry.metricAttributeExpressions, + key: key, + exprCtx: &requestContext.expressionContext, + attrAddFunc: requestContext.telemetry.addMetricAttribute, + }) + + addExpressions(AddExprOpts{ + logger: requestContext.logger, + expressions: requestContext.telemetry.tracingAttributeExpressions, + key: key, + currSpan: currSpan, + exprCtx: &requestContext.expressionContext, + attrAddFunc: requestContext.telemetry.addCommonTraceAttribute, + }) } -func (r *attributeExpressions) expressionsAttributesWithSubgraph(exprCtx *expr.Context) ([]attribute.KeyValue, error) { - return expressionAttributes(r.expressionsWithSubgraph, exprCtx) +func addExpressions(opts AddExprOpts) { + if opts.expressions == nil { + return + } + + attributesForKey, err := opts.expressions.expressionsAttributes(opts.exprCtx, opts.key) + if err != nil { + opts.logger.Error("failed to resolve trace attribute", zap.Error(err)) + return + } + + opts.attrAddFunc(attributesForKey...) + if opts.currSpan != nil { + opts.currSpan.SetAttributes(attributesForKey...) + } } diff --git a/router/core/attribute_expressions_test.go b/router/core/attribute_expressions_test.go index 733d237429..7320fb1f99 100644 --- a/router/core/attribute_expressions_test.go +++ b/router/core/attribute_expressions_test.go @@ -73,12 +73,12 @@ func TestVisitorCheckForRequestAuthAccess_Visit(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - v := VisitorCheckForRequestAuthAccess{} + v := expr.RequestOperationBucketVisitor{} manager := expr.CreateNewExprManager() out, err := manager.CompileExpression(tt.expr, reflect.String, &v) assert.NoError(t, err) assert.NotNil(t, out) - assert.Equal(t, tt.expectedHasAuth, v.HasAuth) + assert.Equal(t, tt.expectedHasAuth, v.Bucket == expr.BucketAuth) }) } @@ -104,8 +104,24 @@ func TestNewAttributeExpressions_SplitsExpressionsUsingAuth(t *testing.T) { attrExpr, err := newAttributeExpressions(attrs, manager) assert.NoError(t, err) require.NotNil(t, attrExpr) - assert.Contains(t, attrExpr.expressions, "attr1") - assert.Contains(t, attrExpr.expressionsWithAuth, "attr2") + + assert.Condition(t, func() bool { + for _, it := range attrExpr.expressions[expr.BucketDefault] { + if it.Key == "attr1" { + return true + } + } + return false + }, "expected Key == attr1 in items") + + assert.Condition(t, func() bool { + for _, it := range attrExpr.expressions[expr.BucketAuth] { + if it.Key == "attr2" { + return true + } + } + return false + }, "expected Key == attr2 in items") reqCtx := requestContext{ expressionContext: expr.Context{ @@ -120,12 +136,12 @@ func TestNewAttributeExpressions_SplitsExpressionsUsingAuth(t *testing.T) { }, } - val, err := attrExpr.expressionsAttributes(&reqCtx.expressionContext) + val, err := attrExpr.expressionsAttributes(&reqCtx.expressionContext, expr.BucketDefault) assert.NoError(t, err) require.Len(t, val, 1) assert.Equal(t, "/some/path", val[0].Value.AsString()) - val2, err2 := attrExpr.expressionsAttributesWithAuth(&reqCtx.expressionContext) + val2, err2 := attrExpr.expressionsAttributes(&reqCtx.expressionContext, expr.BucketAuth) assert.NoError(t, err2) require.Len(t, val2, 1) assert.Equal(t, "yes", val2[0].Value.AsString()) diff --git a/router/core/engine_loader_hooks.go b/router/core/engine_loader_hooks.go index e3da6fe26e..c417582237 100644 --- a/router/core/engine_loader_hooks.go +++ b/router/core/engine_loader_hooks.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/wundergraph/cosmo/router/internal/expr" "slices" "sync/atomic" "time" @@ -172,30 +173,36 @@ func (f *engineLoaderHooks) OnFinished(ctx context.Context, ds resolve.DataSourc metricAttrs = append(metricAttrs, reqContext.telemetry.metricAttrs...) metricAttrs = append(metricAttrs, commonAttrs...) - if f.telemetryAttributeExpressions != nil { - telemetryValues, err := f.telemetryAttributeExpressions.expressionsAttributesWithSubgraph(exprCtx) - if err != nil { - reqContext.Logger().Warn("failed to resolve expression for telemetry", zap.Error(err)) - } - traceAttrs = append(traceAttrs, telemetryValues...) - metricAttrs = append(metricAttrs, telemetryValues...) - } - - if f.tracingAttributeExpressions != nil { - tracingValues, err := f.tracingAttributeExpressions.expressionsAttributesWithSubgraph(exprCtx) - if err != nil { - reqContext.Logger().Warn("failed to resolve expression for tracing", zap.Error(err)) - } - traceAttrs = append(traceAttrs, tracingValues...) - } - - if f.metricAttributeExpressions != nil { - metricValues, err := f.metricAttributeExpressions.expressionsAttributesWithSubgraph(exprCtx) - if err != nil { - reqContext.Logger().Warn("failed to resolve expression for metrics", zap.Error(err)) - } - metricAttrs = append(metricAttrs, metricValues...) - } + addExpressions(AddExprOpts{ + logger: reqContext.logger, + expressions: f.telemetryAttributeExpressions, + key: expr.BucketSubgraph, + currSpan: span, + exprCtx: exprCtx, + attrAddFunc: func(telemetryValues ...attribute.KeyValue) { + traceAttrs = append(traceAttrs, telemetryValues...) + metricAttrs = append(metricAttrs, telemetryValues...) + }, + }) + addExpressions(AddExprOpts{ + logger: reqContext.logger, + expressions: f.tracingAttributeExpressions, + key: expr.BucketSubgraph, + currSpan: span, + exprCtx: exprCtx, + attrAddFunc: func(telemetryValues ...attribute.KeyValue) { + traceAttrs = append(traceAttrs, telemetryValues...) + }, + }) + addExpressions(AddExprOpts{ + logger: reqContext.logger, + expressions: f.metricAttributeExpressions, + key: expr.BucketSubgraph, + exprCtx: exprCtx, + attrAddFunc: func(telemetryValues ...attribute.KeyValue) { + metricAttrs = append(metricAttrs, telemetryValues...) + }, + }) metricAddOpt := otelmetric.WithAttributeSet(attribute.NewSet(metricAttrs...)) diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index 43687e8bfc..dd21615f70 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -214,37 +214,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { routerSpan.SetAttributes(requestContext.telemetry.traceAttrs...) - if requestContext.telemetry.telemetryAttributeExpressions != nil { - traceMetrics, err := requestContext.telemetry.telemetryAttributeExpressions.expressionsAttributes(&requestContext.expressionContext) - if err != nil { - requestLogger.Error("failed to resolve trace attribute", zap.Error(err)) - } - requestContext.telemetry.addCommonAttribute( - traceMetrics..., - ) - routerSpan.SetAttributes(traceMetrics...) - } - - if requestContext.telemetry.metricAttributeExpressions != nil { - metricAttrs, err := requestContext.telemetry.metricAttributeExpressions.expressionsAttributes(&requestContext.expressionContext) - if err != nil { - requestLogger.Error("failed to resolve metric attribute", zap.Error(err)) - } - requestContext.telemetry.addMetricAttribute( - metricAttrs..., - ) - } - - if requestContext.telemetry.tracingAttributeExpressions != nil { - traceMetrics, err := requestContext.telemetry.tracingAttributeExpressions.expressionsAttributes(&requestContext.expressionContext) - if err != nil { - requestLogger.Error("failed to resolve trace attribute", zap.Error(err)) - } - requestContext.telemetry.addCommonTraceAttribute( - traceMetrics..., - ) - routerSpan.SetAttributes(traceMetrics...) - } + setTelemetryAttributes(r.Context(), requestContext, expr.BucketDefault) requestContext.operation = &operationContext{ clientInfo: clientInfo, @@ -391,37 +361,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context()) } - if requestContext.telemetry.telemetryAttributeExpressions != nil { - traceMetrics, err := requestContext.telemetry.telemetryAttributeExpressions.expressionsAttributesWithAuth(&requestContext.expressionContext) - if err != nil { - requestLogger.Error("failed to resolve trace attribute", zap.Error(err)) - } - requestContext.telemetry.addCommonAttribute( - traceMetrics..., - ) - routerSpan.SetAttributes(traceMetrics...) - } - - if requestContext.telemetry.metricAttributeExpressions != nil { - metricAttrs, err := requestContext.telemetry.metricAttributeExpressions.expressionsAttributesWithAuth(&requestContext.expressionContext) - if err != nil { - requestLogger.Error("failed to resolve metric attribute", zap.Error(err)) - } - requestContext.telemetry.addMetricAttribute( - metricAttrs..., - ) - } - - if requestContext.telemetry.tracingAttributeExpressions != nil { - traceMetrics, err := requestContext.telemetry.tracingAttributeExpressions.expressionsAttributesWithAuth(&requestContext.expressionContext) - if err != nil { - requestLogger.Error("failed to resolve trace attribute", zap.Error(err)) - } - requestContext.telemetry.addCommonTraceAttribute( - traceMetrics..., - ) - routerSpan.SetAttributes(traceMetrics...) - } + setTelemetryAttributes(r.Context(), requestContext, expr.BucketAuth) err = h.handleOperation(r, variablesParser, &httpOperation{ requestContext: requestContext, @@ -483,6 +423,10 @@ func (h *PreHandler) shouldComputeOperationSha256(operationKit *OperationKit) bo return true } + if h.exprManager.VisitorManager.IsRequestOperationSha256UsedInExpressions() { + return true + } + hasPersistedHash := operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.HasHash() // If it has a hash already AND a body, we need to compute the hash again to ensure it matches the persisted hash @@ -565,6 +509,10 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson } } requestContext.operation.sha256Hash = operationKit.parsedOperation.Sha256Hash + requestContext.expressionContext.Request.Operation.Sha256Hash = operationKit.parsedOperation.Sha256Hash + + setTelemetryAttributes(req.Context(), requestContext, expr.BucketSha256) + requestContext.telemetry.addCustomMetricStringAttr(ContextFieldOperationSha256, requestContext.operation.sha256Hash) if h.operationBlocker.safelistEnabled || h.operationBlocker.logUnknownOperationsEnabled { // Set the request hash to the parsed hash, to see if it matches a persisted operation @@ -632,7 +580,7 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson // because the operation was already parsed. This is a performance optimization, and we // can do it because we know that the persisted operation is immutable (identified by the hash) if !skipParse { - _, engineParseSpan := h.tracer.Start(req.Context(), "Operation - Parse", + parseCtx, engineParseSpan := h.tracer.Start(req.Context(), "Operation - Parse", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) @@ -650,6 +598,9 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson rtrace.AttachErrToSpan(engineParseSpan, err) requestContext.operation.parsingTime = time.Since(startParsing) + requestContext.expressionContext.Request.Operation.ParsingTime = requestContext.operation.parsingTime + setTelemetryAttributes(parseCtx, requestContext, expr.BucketParsingTime) + if !requestContext.operation.traceOptions.ExcludeParseStats { httpOperation.traceTimings.EndParse() } @@ -660,6 +611,9 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson } requestContext.operation.parsingTime = time.Since(startParsing) + requestContext.expressionContext.Request.Operation.ParsingTime = requestContext.operation.parsingTime + setTelemetryAttributes(parseCtx, requestContext, expr.BucketParsingTime) + if !requestContext.operation.traceOptions.ExcludeParseStats { httpOperation.traceTimings.EndParse() } @@ -670,7 +624,11 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson requestContext.operation.name = operationKit.parsedOperation.Request.OperationName requestContext.operation.opType = operationKit.parsedOperation.Type - setExpressionContextOperation(requestContext) + requestContext.expressionContext.Request.Operation.Name = requestContext.operation.name + requestContext.expressionContext.Request.Operation.Type = requestContext.operation.opType + + setTelemetryAttributes(req.Context(), requestContext, expr.BucketNameOrType) + setExpressionContextClient(requestContext) attributesAfterParse := []attribute.KeyValue{ @@ -723,6 +681,9 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson if operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.HasHash() { hash := operationKit.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash requestContext.operation.persistedID = hash + requestContext.expressionContext.Request.Operation.PersistedID = hash + setTelemetryAttributes(req.Context(), requestContext, expr.BucketPersistedID) + persistedIDAttribute := otel.WgOperationPersistedID.String(hash) requestContext.telemetry.addCommonAttribute(persistedIDAttribute) @@ -740,7 +701,7 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson startNormalization := time.Now() - _, engineNormalizeSpan := h.tracer.Start(req.Context(), "Operation - Normalize", + normalizeCtx, engineNormalizeSpan := h.tracer.Start(req.Context(), "Operation - Normalize", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) @@ -750,6 +711,9 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson rtrace.AttachErrToSpan(engineNormalizeSpan, err) requestContext.operation.normalizationTime = time.Since(startNormalization) + requestContext.expressionContext.Request.Operation.NormalizationTime = requestContext.operation.normalizationTime + setTelemetryAttributes(normalizeCtx, requestContext, expr.BucketNormalizationTime) + if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() } @@ -779,6 +743,8 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson rtrace.AttachErrToSpan(engineNormalizeSpan, err) requestContext.operation.normalizationTime = time.Since(startNormalization) + requestContext.expressionContext.Request.Operation.NormalizationTime = requestContext.operation.normalizationTime + setTelemetryAttributes(normalizeCtx, requestContext, expr.BucketNormalizationTime) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() @@ -821,6 +787,8 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson rtrace.AttachErrToSpan(engineNormalizeSpan, err) requestContext.operation.normalizationTime = time.Since(startNormalization) + requestContext.expressionContext.Request.Operation.NormalizationTime = requestContext.operation.normalizationTime + setTelemetryAttributes(normalizeCtx, requestContext, expr.BucketNormalizationTime) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() @@ -839,7 +807,6 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson if requestContext.operation.hash != 0 { operationHash = requestContext.operation.HashString() } - requestContext.expressionContext.Request.Operation.Hash = operationHash if !h.disableVariablesRemapping && len(uploadsMapping) > 0 { // after variables remapping we need to update the file uploads path because variables relative path has changed @@ -897,6 +864,11 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson return err } requestContext.operation.normalizationTime = time.Since(startNormalization) + requestContext.expressionContext.Request.Operation.NormalizationTime = requestContext.operation.normalizationTime + setTelemetryAttributes(normalizeCtx, requestContext, expr.BucketNormalizationTime) + + requestContext.expressionContext.Request.Operation.Hash = operationHash + setTelemetryAttributes(normalizeCtx, requestContext, expr.BucketHash) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { httpOperation.traceTimings.EndNormalize() @@ -907,12 +879,12 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson engineNormalizeSpan.SetAttributes(otel.WgOperationNormalizedContent.String(operationKit.parsedOperation.NormalizedRepresentation)) } - engineNormalizeSpan.End() - if operationKit.parsedOperation.IsPersistedOperation { engineNormalizeSpan.SetAttributes(otel.WgEnginePersistedOperationCacheHit.Bool(operationKit.parsedOperation.PersistedOperationCacheHit)) } + engineNormalizeSpan.End() + if h.traceExportVariables { // At this stage the variables are normalized httpOperation.routerSpan.SetAttributes(otel.WgOperationVariables.String(string(operationKit.parsedOperation.Request.Variables))) @@ -928,7 +900,7 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson startValidation := time.Now() - _, engineValidateSpan := h.tracer.Start(req.Context(), "Operation - Validate", + validationCtx, engineValidateSpan := h.tracer.Start(req.Context(), "Operation - Validate", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) @@ -946,6 +918,9 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson rtrace.AttachErrToSpan(engineValidateSpan, err) requestContext.operation.validationTime = time.Since(startValidation) + requestContext.expressionContext.Request.Operation.ValidationTime = requestContext.operation.validationTime + setTelemetryAttributes(validationCtx, requestContext, expr.BucketValidationTime) + httpOperation.traceTimings.EndValidate() engineValidateSpan.End() @@ -960,6 +935,8 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson requestContext.graphQLErrorCodes = append(requestContext.graphQLErrorCodes, h.getErrorCodes(err)...) requestContext.operation.validationTime = time.Since(startValidation) + requestContext.expressionContext.Request.Operation.ValidationTime = requestContext.operation.validationTime + setTelemetryAttributes(validationCtx, requestContext, expr.BucketValidationTime) if !requestContext.operation.traceOptions.ExcludeValidateStats { httpOperation.traceTimings.EndValidate() @@ -979,6 +956,9 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson } requestContext.operation.validationTime = time.Since(startValidation) + requestContext.expressionContext.Request.Operation.ValidationTime = requestContext.operation.validationTime + setTelemetryAttributes(validationCtx, requestContext, expr.BucketValidationTime) + httpOperation.traceTimings.EndValidate() engineValidateSpan.End() @@ -996,7 +976,7 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson startPlanning := time.Now() - _, enginePlanSpan := h.tracer.Start(req.Context(), "Operation - Plan", + planCtx, enginePlanSpan := h.tracer.Start(req.Context(), "Operation - Plan", trace.WithSpanKind(trace.SpanKindInternal), trace.WithAttributes(otel.WgEngineRequestTracingEnabled.Bool(requestContext.operation.traceOptions.Enable)), trace.WithAttributes(requestContext.telemetry.traceAttrs...), @@ -1018,6 +998,8 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson } requestContext.operation.planningTime = time.Since(startPlanning) + requestContext.expressionContext.Request.Operation.PlanningTime = requestContext.operation.planningTime + setTelemetryAttributes(planCtx, requestContext, expr.BucketPlanningTime) rtrace.AttachErrToSpan(enginePlanSpan, err) enginePlanSpan.End() @@ -1030,6 +1012,8 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson } requestContext.operation.planningTime = time.Since(startPlanning) + requestContext.expressionContext.Request.Operation.PlanningTime = requestContext.operation.planningTime + setTelemetryAttributes(planCtx, requestContext, expr.BucketPlanningTime) enginePlanSpan.SetAttributes(otel.WgEnginePlanCacheHit.Bool(requestContext.operation.planCacheHit)) enginePlanSpan.End() @@ -1195,13 +1179,6 @@ func (h *PreHandler) parseRequestExecutionOptions(r *http.Request) resolve.Execu return options } -func setExpressionContextOperation(requestContext *requestContext) { - requestContext.expressionContext.Request.Operation = expr.Operation{ - Name: requestContext.operation.name, - Type: requestContext.operation.opType, - } -} - func setExpressionContextClient(requestContext *requestContext) { clientName := requestContext.operation.clientInfo.Name if clientName == "unknown" { @@ -1214,9 +1191,7 @@ func setExpressionContextClient(requestContext *requestContext) { } if clientName != "" || clientVersion != "" { - requestContext.expressionContext.Request.Client = expr.Client{ - Name: clientName, - Version: clientVersion, - } + requestContext.expressionContext.Request.Client.Name = clientName + requestContext.expressionContext.Request.Client.Version = clientVersion } } diff --git a/router/core/transport.go b/router/core/transport.go index d972941e83..c4dbf61054 100644 --- a/router/core/transport.go +++ b/router/core/transport.go @@ -135,7 +135,7 @@ func (ct *CustomTransport) measureSubgraphMetrics(req *http.Request) func(err er attributes = append(attributes, reqContext.telemetry.metricAttrs...) if reqContext.telemetry.metricAttributeExpressions != nil { - additionalAttrs, err := reqContext.telemetry.metricAttributeExpressions.expressionsAttributes(&reqContext.expressionContext) + additionalAttrs, err := reqContext.telemetry.metricAttributeExpressions.expressionsAttributes(&reqContext.expressionContext, expr.BucketDefault) if err != nil { ct.logger.Error("failed to resolve metric attribute expressions", zap.Error(err)) } @@ -487,21 +487,25 @@ func CreateGRPCTraceGetter( attrs = append(attrs, traceAttrs...) attrs = append(attrs, reqCtx.telemetry.traceAttrs...) - if telemetryAttributeExpressions != nil { - telemetryValues, err := telemetryAttributeExpressions.expressionsAttributesWithSubgraph(&reqCtx.expressionContext) - if err != nil { - reqCtx.Logger().Warn("failed to resolve grpc plugin expression for telemetry", zap.Error(err)) - } + spanAttrFunc := func(telemetryValues ...attribute.KeyValue) { attrs = append(attrs, telemetryValues...) } - if tracingAttributeExpressions != nil { - tracingValues, err := tracingAttributeExpressions.expressionsAttributesWithSubgraph(&reqCtx.expressionContext) - if err != nil { - reqCtx.Logger().Warn("failed to resolve grpc plugin expression for tracing", zap.Error(err)) - } - attrs = append(attrs, tracingValues...) - } + addExpressions(AddExprOpts{ + logger: reqCtx.logger, + expressions: telemetryAttributeExpressions, + key: expr.BucketSubgraph, + exprCtx: &reqCtx.expressionContext, + attrAddFunc: spanAttrFunc, + }) + + addExpressions(AddExprOpts{ + logger: reqCtx.logger, + expressions: tracingAttributeExpressions, + key: expr.BucketSubgraph, + exprCtx: &reqCtx.expressionContext, + attrAddFunc: spanAttrFunc, + }) // Override http operation protocol with grpc attrs = append(attrs, otel.EngineTransportAttribute, otel.WgOperationProtocol.String(OperationProtocolGRPC.String())) diff --git a/router/internal/expr/expr.go b/router/internal/expr/expr.go index 76bd017427..64f3324958 100644 --- a/router/internal/expr/expr.go +++ b/router/internal/expr/expr.go @@ -79,9 +79,15 @@ type Response struct { } type Operation struct { - Name string `expr:"name"` - Type string `expr:"type"` - Hash string `expr:"hash"` + Sha256Hash string `expr:"sha256Hash"` + ParsingTime time.Duration `expr:"parsingTime"` + Name string `expr:"name"` + Type string `expr:"type"` + PersistedID string `expr:"persistedId"` + NormalizationTime time.Duration `expr:"normalizationTime"` + Hash string `expr:"hash"` + ValidationTime time.Duration `expr:"validationTime"` + PlanningTime time.Duration `expr:"planningTime"` } type Client struct { diff --git a/router/internal/expr/request_operation_bucket_visitor.go b/router/internal/expr/request_operation_bucket_visitor.go new file mode 100644 index 0000000000..9e64106683 --- /dev/null +++ b/router/internal/expr/request_operation_bucket_visitor.go @@ -0,0 +1,116 @@ +package expr + +import ( + "github.com/expr-lang/expr/ast" +) + +// AttributeBucket indicates the highest-priority usage detected in an expression +type AttributeBucket uint8 + +const ( + BucketDefault AttributeBucket = iota + BucketAuth + BucketSha256 + BucketParsingTime + BucketNameOrType + BucketPersistedID + BucketNormalizationTime + BucketHash + BucketValidationTime + BucketPlanningTime + BucketSubgraph +) + +// RequestOperationBucketVisitor inspects nodes and sets Bucket to the highest-priority match +// Priority (low -> high): any, auth, sha256, parsingTime, name/type, persistedId, normalizationTime, +// hash, validationTime, planningTime, subgraph +type RequestOperationBucketVisitor struct { + Bucket AttributeBucket +} + +func (v *RequestOperationBucketVisitor) Visit(baseNode *ast.Node) { + if baseNode == nil || v.Bucket == BucketSubgraph { + return + } + + // Detect subgraph usage (highest priority) + if ident, ok := (*baseNode).(*ast.IdentifierNode); ok { + if ident.Value == "subgraph" { + v.setBucketIfHigher(BucketSubgraph) + return + } + } + + if member, ok := (*baseNode).(*ast.MemberNode); ok { + // subgraph.* also qualifies + if ident, ok := member.Node.(*ast.IdentifierNode); ok && ident.Value == "subgraph" { + v.setBucketIfHigher(BucketSubgraph) + return + } + + // request.auth (lowest priority) + prop := getPropValue(member) + if prop == "" { + return + } + + if prop == "auth" { + if reqIdent, ok := member.Node.(*ast.IdentifierNode); ok && reqIdent.Value == ExprRequestKey { + v.setBucketIfHigher(BucketAuth) + // don't return as higher-priority matches may exist in other nodes as child nodes + } + } + + // Ensure parent is request.operation + opMember, ok := member.Node.(*ast.MemberNode) + if !ok { + return + } + + opProp := getPropValue(opMember) + if opProp != "operation" { + return + } + + if reqIdent, ok := opMember.Node.(*ast.IdentifierNode); !ok || reqIdent.Value != ExprRequestKey { + return + } + + // Map property to bucket + switch prop { + case "sha256Hash": + v.setBucketIfHigher(BucketSha256) + case "parsingTime": + v.setBucketIfHigher(BucketParsingTime) + case "name", "type": + v.setBucketIfHigher(BucketNameOrType) + case "persistedId": + v.setBucketIfHigher(BucketPersistedID) + case "normalizationTime": + v.setBucketIfHigher(BucketNormalizationTime) + case "hash": + v.setBucketIfHigher(BucketHash) + case "validationTime": + v.setBucketIfHigher(BucketValidationTime) + case "planningTime": + v.setBucketIfHigher(BucketPlanningTime) + } + } +} + +func (v *RequestOperationBucketVisitor) setBucketIfHigher(bucket AttributeBucket) { + if bucket > v.Bucket { + v.Bucket = bucket + } +} + +func getPropValue(member *ast.MemberNode) string { + prop := "" + switch p := member.Property.(type) { + case *ast.StringNode: + prop = p.Value + case *ast.IdentifierNode: + prop = p.Value + } + return prop +} diff --git a/router/internal/expr/request_operation_bucket_visitor_test.go b/router/internal/expr/request_operation_bucket_visitor_test.go new file mode 100644 index 0000000000..9e4989b09e --- /dev/null +++ b/router/internal/expr/request_operation_bucket_visitor_test.go @@ -0,0 +1,421 @@ +package expr + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestRequestOperationBucketVisitor validates that expressions are correctly classified into buckets +// based on the attributes they access. +// +// Priority (low → high): Default < Auth < Sha256 < ParsingTime < NameOrType < PersistedID < +// NormalizationTime < Hash < ValidationTime < PlanningTime < Subgraph + +func TestRequestOperationBucketVisitor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expression string + expectedBucket AttributeBucket + description string + }{ + // BucketDefault - no matching attributes + { + name: "no operation attributes", + expression: `"static value"`, + expectedBucket: BucketDefault, + description: "Expression with no request attributes should use default bucket", + }, + { + name: "request without operation", + expression: `request.url.method == "POST"`, + expectedBucket: BucketDefault, + description: "Request attributes other than operation should use default bucket", + }, + + // BucketAuth - request.auth (lowest priority operation attribute) + { + name: "request.auth.claims", + expression: `request.auth.claims["sub"] == "user123"`, + expectedBucket: BucketAuth, + description: "Auth claims access should use auth bucket", + }, + { + name: "request.auth.scopes", + expression: `"admin" in request.auth.scopes`, + expectedBucket: BucketAuth, + description: "Auth scopes access should use auth bucket", + }, + { + name: "request.auth with string property", + expression: `request["auth"]["claims"]`, + expectedBucket: BucketAuth, + description: "Auth access with bracket notation should use auth bucket", + }, + + // BucketSha256 - request.operation.sha256Hash + { + name: "sha256Hash", + expression: `request.operation.sha256Hash == "abc123"`, + expectedBucket: BucketSha256, + description: "SHA256 hash access should use sha256 bucket", + }, + { + name: "sha256Hash with bracket notation", + expression: `request["operation"]["sha256Hash"]`, + expectedBucket: BucketSha256, + description: "SHA256 hash with bracket notation should use sha256 bucket", + }, + { + name: "sha256Hash in condition", + expression: `request.operation.sha256Hash != "" && request.url.method == "POST"`, + expectedBucket: BucketSha256, + description: "SHA256 hash in complex condition should use sha256 bucket", + }, + + // BucketParsingTime - request.operation.parsingTime + { + name: "parsingTime", + expression: `request.operation.parsingTime`, + expectedBucket: BucketParsingTime, + description: "Parsing time access should use parsing time bucket", + }, + { + name: "parsingTime with bracket", + expression: `request.operation["parsingTime"]`, + expectedBucket: BucketParsingTime, + description: "Parsing time with bracket notation should use parsing time bucket", + }, + + // BucketNameOrType - request.operation.name or request.operation.type + { + name: "operation name", + expression: `request.operation.name == "GetUser"`, + expectedBucket: BucketNameOrType, + description: "Operation name access should use name/type bucket", + }, + { + name: "operation type", + expression: `request.operation.type == "query"`, + expectedBucket: BucketNameOrType, + description: "Operation type access should use name/type bucket", + }, + { + name: "operation name in conditional", + expression: `request.operation.name != "" ? "named" : "anonymous"`, + expectedBucket: BucketNameOrType, + description: "Operation name in ternary should use name/type bucket", + }, + + // BucketPersistedID - request.operation.persistedId + { + name: "persistedId", + expression: `request.operation.persistedId == "abc123"`, + expectedBucket: BucketPersistedID, + description: "Persisted ID access should use persisted ID bucket", + }, + { + name: "persistedId existence check", + expression: `request.operation.persistedId != ""`, + expectedBucket: BucketPersistedID, + description: "Persisted ID check should use persisted ID bucket", + }, + + // BucketNormalizationTime - request.operation.normalizationTime + { + name: "normalizationTime", + expression: `request.operation.normalizationTime`, + expectedBucket: BucketNormalizationTime, + description: "Normalization time access should use normalization time bucket", + }, + { + name: "normalizationTime comparison", + expression: `request.operation.normalizationTime < request.operation.parsingTime`, + expectedBucket: BucketNormalizationTime, + description: "Normalization time is higher priority than parsing time", + }, + + // BucketHash - request.operation.hash + { + name: "operation hash", + expression: `request.operation.hash == "xyz789"`, + expectedBucket: BucketHash, + description: "Operation hash access should use hash bucket", + }, + { + name: "hash with bracket notation", + expression: `request["operation"]["hash"]`, + expectedBucket: BucketHash, + description: "Hash with bracket notation should use hash bucket", + }, + + // BucketValidationTime - request.operation.validationTime + { + name: "validationTime", + expression: `request.operation.validationTime`, + expectedBucket: BucketValidationTime, + description: "Validation time access should use validation time bucket", + }, + { + name: "validationTime vs hash priority", + expression: `request.operation.validationTime > request.operation.normalizationTime && request.operation.hash != ""`, + expectedBucket: BucketValidationTime, + description: "Validation time is higher priority than hash", + }, + + // BucketPlanningTime - request.operation.planningTime + { + name: "planningTime", + expression: `request.operation.planningTime`, + expectedBucket: BucketPlanningTime, + description: "Planning time access should use planning time bucket", + }, + { + name: "planningTime comparison", + expression: `request.operation.planningTime + request.operation.validationTime`, + expectedBucket: BucketPlanningTime, + description: "Planning time is higher priority than validation time", + }, + + // BucketSubgraph - subgraph or subgraph.* (highest priority) + { + name: "subgraph identifier", + expression: `subgraph`, + expectedBucket: BucketSubgraph, + description: "Direct subgraph reference should use subgraph bucket", + }, + { + name: "subgraph.name", + expression: `subgraph.name == "products"`, + expectedBucket: BucketSubgraph, + description: "Subgraph property access should use subgraph bucket", + }, + { + name: "subgraph in condition", + expression: `subgraph.name == "users" && request.url.method == "POST"`, + expectedBucket: BucketSubgraph, + description: "Subgraph in condition should use subgraph bucket", + }, + { + name: "subgraph vs all operation attributes", + expression: `subgraph.name + request.operation.hash + request.operation.name`, + expectedBucket: BucketSubgraph, + description: "Subgraph is highest priority even with other attributes", + }, + + // Priority tests - multiple attributes with different priorities + { + name: "auth and sha256 - sha256 wins", + expression: `request.auth.claims["sub"] == "user" && request.operation.sha256Hash == "abc"`, + expectedBucket: BucketSha256, + description: "SHA256 should win over auth (higher priority)", + }, + { + name: "sha256 and name - name wins", + expression: `request.operation.sha256Hash + request.operation.name`, + expectedBucket: BucketNameOrType, + description: "Name should win over sha256 (higher priority)", + }, + { + name: "name and persistedId - persistedId wins", + expression: `request.operation.name == "Query" && request.operation.persistedId != ""`, + expectedBucket: BucketPersistedID, + description: "Persisted ID should win over name (higher priority)", + }, + { + name: "persistedId and hash - hash wins", + expression: `request.operation.persistedId + request.operation.hash`, + expectedBucket: BucketHash, + description: "Hash should win over persisted ID (higher priority)", + }, + { + name: "hash and validationTime - validationTime wins", + expression: `request.operation.hash == "xyz" && request.operation.validationTime > request.operation.parsingTime`, + expectedBucket: BucketValidationTime, + description: "Validation time should win over hash (higher priority)", + }, + { + name: "validationTime and planningTime - planningTime wins", + expression: `request.operation.validationTime + request.operation.planningTime`, + expectedBucket: BucketPlanningTime, + description: "Planning time should win over validation time (higher priority)", + }, + { + name: "planningTime and subgraph - subgraph wins", + expression: `request.operation.planningTime > request.operation.validationTime && subgraph.name != ""`, + expectedBucket: BucketSubgraph, + description: "Subgraph should win over planning time (highest priority)", + }, + + // Complex expressions + { + name: "nested conditional with multiple attributes", + expression: `request.operation.type == "mutation" ? request.operation.name : request.auth.claims["sub"]`, + expectedBucket: BucketNameOrType, + description: "Name/type should win in nested conditional with auth", + }, + { + name: "complex boolean expression", + expression: `(request.operation.parsingTime > request.operation.validationTime) || (request.operation.planningTime > request.operation.parsingTime)`, + expectedBucket: BucketPlanningTime, + description: "Planning time should be detected in complex boolean expression", + }, + { + name: "string concatenation", + expression: `request.operation.name + "-" + request.operation.hash`, + expectedBucket: BucketHash, + description: "Hash should win in string concatenation with name", + }, + + // Edge cases + { + name: "operation without specific property", + expression: `request.operation`, + expectedBucket: BucketDefault, + description: "Operation without property access should use default bucket", + }, + { + name: "auth combined with non-operation", + expression: `request.auth.claims["role"] + request.url.method`, + expectedBucket: BucketAuth, + description: "Auth with non-operation attributes should use auth bucket", + }, + { + name: "mixed identifier and string property access", + expression: `request["operation"].sha256Hash`, + expectedBucket: BucketSha256, + description: "Mixed bracket and dot notation should work for sha256", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create a new expression manager + exprManager := CreateNewExprManager() + + // Create the visitor + visitor := &RequestOperationBucketVisitor{ + Bucket: BucketDefault, + } + + // Compile the expression with the visitor + _, err := exprManager.CompileAnyExpression(tt.expression, visitor) + require.NoError(t, err, "Failed to compile expression: %s", tt.expression) + + // Assert the bucket matches expected + assert.Equal(t, tt.expectedBucket, visitor.Bucket, + "Description: %s\nExpression: %s\nExpected bucket: %v (priority %d)\nGot bucket: %v (priority %d)", + tt.description, tt.expression, bucketName(tt.expectedBucket), tt.expectedBucket, + bucketName(visitor.Bucket), visitor.Bucket) + }) + } +} + +// bucketName returns a human-readable name for the bucket (for better test output) +func bucketName(bucket AttributeBucket) string { + switch bucket { + case BucketDefault: + return "BucketDefault" + case BucketAuth: + return "BucketAuth" + case BucketSha256: + return "BucketSha256" + case BucketParsingTime: + return "BucketParsingTime" + case BucketNameOrType: + return "BucketNameOrType" + case BucketPersistedID: + return "BucketPersistedID" + case BucketNormalizationTime: + return "BucketNormalizationTime" + case BucketHash: + return "BucketHash" + case BucketValidationTime: + return "BucketValidationTime" + case BucketPlanningTime: + return "BucketPlanningTime" + case BucketSubgraph: + return "BucketSubgraph" + default: + return "Unknown" + } +} + +// TestBucketPriority verifies the priority order is correct +func TestBucketPriority(t *testing.T) { + t.Parallel() + + // This test verifies the priority order defined in the constants + // which would alert in case someone would change it + + assert.True(t, BucketDefault < BucketAuth, "Default should be lower priority than Auth") + assert.True(t, BucketAuth < BucketSha256, "Auth should be lower priority than Sha256") + assert.True(t, BucketSha256 < BucketParsingTime, "Sha256 should be lower priority than ParsingTime") + assert.True(t, BucketParsingTime < BucketNameOrType, "ParsingTime should be lower priority than NameOrType") + assert.True(t, BucketNameOrType < BucketPersistedID, "NameOrType should be lower priority than PersistedID") + assert.True(t, BucketPersistedID < BucketNormalizationTime, "PersistedID should be lower priority than NormalizationTime") + assert.True(t, BucketNormalizationTime < BucketHash, "NormalizationTime should be lower priority than Hash") + assert.True(t, BucketHash < BucketValidationTime, "Hash should be lower priority than ValidationTime") + assert.True(t, BucketValidationTime < BucketPlanningTime, "ValidationTime should be lower priority than PlanningTime") + assert.True(t, BucketPlanningTime < BucketSubgraph, "PlanningTime should be lower priority than Subgraph") +} + +// TestSetBucketIfHigher verifies the setBucketIfHigher logic +func TestSetBucketIfHigher(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + currentBucket AttributeBucket + newBucket AttributeBucket + expectedBucket AttributeBucket + }{ + { + name: "lower priority should not update", + currentBucket: BucketHash, + newBucket: BucketSha256, + expectedBucket: BucketHash, + }, + { + name: "higher priority should update", + currentBucket: BucketSha256, + newBucket: BucketHash, + expectedBucket: BucketHash, + }, + { + name: "same priority should not update", + currentBucket: BucketHash, + newBucket: BucketHash, + expectedBucket: BucketHash, + }, + { + name: "subgraph should always win", + currentBucket: BucketPlanningTime, + newBucket: BucketSubgraph, + expectedBucket: BucketSubgraph, + }, + { + name: "nothing beats subgraph", + currentBucket: BucketSubgraph, + newBucket: BucketPlanningTime, + expectedBucket: BucketSubgraph, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + visitor := &RequestOperationBucketVisitor{ + Bucket: tt.currentBucket, + } + visitor.setBucketIfHigher(tt.newBucket) + assert.Equal(t, tt.expectedBucket, visitor.Bucket) + }) + } +} diff --git a/router/internal/expr/use_request_operation_sha256.go b/router/internal/expr/use_request_operation_sha256.go new file mode 100644 index 0000000000..7b25cb9049 --- /dev/null +++ b/router/internal/expr/use_request_operation_sha256.go @@ -0,0 +1,68 @@ +package expr + +import ( + "github.com/expr-lang/expr/ast" +) + +const ( + sha256HashAttributeName = "sha256Hash" + operationAttributeName = "operation" +) + +// UsesRequestOperationSha256 detects whether an expression references request.operation.sha256Hash +type UsesRequestOperationSha256 struct { + UsesRequestOperationSha256 bool +} + +func (v *UsesRequestOperationSha256) Visit(baseNode *ast.Node) { + if baseNode == nil || v.UsesRequestOperationSha256 { + return + } + + // Check if it's a member access ending with "sha256Hash" + shaAccess, ok := (*baseNode).(*ast.MemberNode) + if !ok { + return + } + + // Property should be "sha256Hash" + switch p := shaAccess.Property.(type) { + case *ast.StringNode: + if p.Value != sha256HashAttributeName { + return + } + case *ast.IdentifierNode: + if p.Value != sha256HashAttributeName { + return + } + default: + return + } + + // Parent should be a member access to "operation" + operationAccess, ok := shaAccess.Node.(*ast.MemberNode) + if !ok { + return + } + + switch op := operationAccess.Property.(type) { + case *ast.StringNode: + if op.Value != operationAttributeName { + return + } + case *ast.IdentifierNode: + if op.Value != operationAttributeName { + return + } + default: + return + } + + // Root should be identifier "request" + requestIdent, ok := operationAccess.Node.(*ast.IdentifierNode) + if !ok || requestIdent.Value != "request" { + return + } + + v.UsesRequestOperationSha256 = true +} diff --git a/router/internal/expr/uses_request_operation_sha256_test.go b/router/internal/expr/uses_request_operation_sha256_test.go new file mode 100644 index 0000000000..4111663de6 --- /dev/null +++ b/router/internal/expr/uses_request_operation_sha256_test.go @@ -0,0 +1,122 @@ +package expr + +import ( + "testing" + + "github.com/expr-lang/expr/ast" + "github.com/stretchr/testify/assert" +) + +func TestUsesRequestOperationSha256(t *testing.T) { + t.Parallel() + + t.Run("nil node", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{} + visitor.Visit(nil) + assert.False(t, visitor.UsesRequestOperationSha256) + }) + + t.Run("request.operation.sha256Hash access", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{} + node := ast.Node(&ast.MemberNode{ + Node: &ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "request"}, + Property: &ast.StringNode{Value: "operation"}, + }, + Property: &ast.StringNode{Value: "sha256Hash"}, + }) + visitor.Visit(&node) + assert.True(t, visitor.UsesRequestOperationSha256) + }) + + t.Run("request[\"operation\"][\"sha256Hash\"] access", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{} + node := ast.Node(&ast.MemberNode{ + Node: &ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "request"}, + Property: &ast.IdentifierNode{Value: "operation"}, + }, + Property: &ast.IdentifierNode{Value: "sha256Hash"}, + }) + visitor.Visit(&node) + assert.True(t, visitor.UsesRequestOperationSha256) + }) + + t.Run("request[\"operation\"].sha256Hash access", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{} + node := ast.Node(&ast.MemberNode{ + Node: &ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "request"}, + Property: &ast.StringNode{Value: "operation"}, + }, + Property: &ast.IdentifierNode{Value: "sha256Hash"}, + }) + visitor.Visit(&node) + assert.True(t, visitor.UsesRequestOperationSha256) + }) + + t.Run("request.operation.hash access - not sha256", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{} + node := ast.Node(&ast.MemberNode{ + Node: &ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "request"}, + Property: &ast.StringNode{Value: "operation"}, + }, + Property: &ast.StringNode{Value: "hash"}, + }) + visitor.Visit(&node) + assert.False(t, visitor.UsesRequestOperationSha256) + }) + + t.Run("other.operation.sha256Hash access - wrong root", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{} + node := ast.Node(&ast.MemberNode{ + Node: &ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "other"}, + Property: &ast.StringNode{Value: "operation"}, + }, + Property: &ast.StringNode{Value: "sha256Hash"}, + }) + visitor.Visit(&node) + assert.False(t, visitor.UsesRequestOperationSha256) + }) + + t.Run("request.body.sha256Hash access - wrong middle", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{} + node := ast.Node(&ast.MemberNode{ + Node: &ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "request"}, + Property: &ast.StringNode{Value: "body"}, + }, + Property: &ast.StringNode{Value: "sha256Hash"}, + }) + visitor.Visit(&node) + assert.False(t, visitor.UsesRequestOperationSha256) + }) + + t.Run("already set short-circuit", func(t *testing.T) { + t.Parallel() + + visitor := &UsesRequestOperationSha256{UsesRequestOperationSha256: true} + node := ast.Node(&ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "request"}, + Property: &ast.StringNode{Value: "anything"}, + }) + visitor.Visit(&node) + assert.True(t, visitor.UsesRequestOperationSha256) + }) +} diff --git a/router/internal/expr/visitor_group.go b/router/internal/expr/visitor_group.go index 539846c3ad..ab18f88f58 100644 --- a/router/internal/expr/visitor_group.go +++ b/router/internal/expr/visitor_group.go @@ -9,6 +9,7 @@ const ( usesSubgraphTraceKey usesResponseBodyKey usesSubgraphResponseBodyKey + usesRequestOperationSha256Key ) // VisitorGroup is a struct that holds all the VisitorManager that are used to compile the expressions @@ -22,10 +23,11 @@ type VisitorGroup struct { func createVisitorGroup() *VisitorGroup { return &VisitorGroup{ globalVisitors: map[visitorKind]ast.Visitor{ - usesRequestBodyKey: &UsesBody{}, - usesSubgraphTraceKey: &UsesSubgraphTrace{}, - usesResponseBodyKey: &UsesResponseBody{}, - usesSubgraphResponseBodyKey: &UsesSubgraphResponseBody{}, + usesRequestBodyKey: &UsesBody{}, + usesSubgraphTraceKey: &UsesSubgraphTrace{}, + usesResponseBodyKey: &UsesResponseBody{}, + usesSubgraphResponseBodyKey: &UsesSubgraphResponseBody{}, + usesRequestOperationSha256Key: &UsesRequestOperationSha256{}, }, } } @@ -61,3 +63,11 @@ func (c *VisitorGroup) IsSubgraphResponseBodyUsedInExpressions() bool { body := c.globalVisitors[usesSubgraphResponseBodyKey].(*UsesSubgraphResponseBody) return body.UsesSubgraphResponseBody } + +func (c *VisitorGroup) IsRequestOperationSha256UsedInExpressions() bool { + if c == nil { + return true + } + v := c.globalVisitors[usesRequestOperationSha256Key].(*UsesRequestOperationSha256) + return v.UsesRequestOperationSha256 +} diff --git a/router/internal/expr/visitor_group_test.go b/router/internal/expr/visitor_group_test.go index 30e621b0ab..7ec2dd48f2 100644 --- a/router/internal/expr/visitor_group_test.go +++ b/router/internal/expr/visitor_group_test.go @@ -64,6 +64,51 @@ func TestVisitorManager(t *testing.T) { } }) + t.Run("verify IsRequestOperationSha256UsedInExpressions", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + expression string + expectedResult bool + }{ + { + name: "without sha256", + expression: "request.operation.hash", + expectedResult: false, + }, + { + name: "with sha256 dot chaining", + expression: "request.operation.sha256Hash", + expectedResult: true, + }, + { + name: "with sha256 square bracket", + expression: `request["operation"]["sha256Hash"]`, + expectedResult: true, + }, + { + name: "with sha256 mixed access", + expression: `request["operation"].sha256Hash`, + expectedResult: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + exprManager := CreateNewExprManager() + visitorManager := exprManager.VisitorManager + + _, err := exprManager.CompileAnyExpression(tc.expression) + require.NoError(t, err) + + require.Equal(t, tc.expectedResult, visitorManager.IsRequestOperationSha256UsedInExpressions()) + }) + } + }) + t.Run("verify IsSubgraphResponseBodyUsedInExpressions", func(t *testing.T) { t.Parallel()