diff --git a/execution/graphql/complexity.go b/execution/graphql/complexity.go deleted file mode 100644 index dbecd687a2..0000000000 --- a/execution/graphql/complexity.go +++ /dev/null @@ -1,76 +0,0 @@ -package graphql - -import ( - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/graphqlerrors" - "github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity" - "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" -) - -var DefaultComplexityCalculator = defaultComplexityCalculator{} - -type ComplexityCalculator interface { - Calculate(operation, definition *ast.Document) (ComplexityResult, error) -} - -type defaultComplexityCalculator struct { -} - -func (d defaultComplexityCalculator) Calculate(operation, definition *ast.Document) (ComplexityResult, error) { - report := operationreport.Report{} - globalComplexityResult, fieldsComplexityResult := operation_complexity.CalculateOperationComplexity(operation, definition, &report) - - return complexityResult(globalComplexityResult, fieldsComplexityResult, report) -} - -type ComplexityResult struct { - NodeCount int - Complexity int - Depth int - PerRootField []FieldComplexityResult - Errors graphqlerrors.Errors -} - -type FieldComplexityResult struct { - TypeName string - FieldName string - Alias string - NodeCount int - Complexity int - Depth int -} - -func complexityResult(globalComplexityResult operation_complexity.OperationStats, fieldsComplexityResult []operation_complexity.RootFieldStats, report operationreport.Report) (ComplexityResult, error) { - allFieldComplexityResults := make([]FieldComplexityResult, 0, len(fieldsComplexityResult)) - for _, fieldResult := range fieldsComplexityResult { - allFieldComplexityResults = append(allFieldComplexityResults, FieldComplexityResult{ - TypeName: fieldResult.TypeName, - FieldName: fieldResult.FieldName, - Alias: fieldResult.Alias, - NodeCount: fieldResult.Stats.NodeCount, - Complexity: fieldResult.Stats.Complexity, - Depth: fieldResult.Stats.Depth, - }) - } - - result := ComplexityResult{ - NodeCount: globalComplexityResult.NodeCount, - Complexity: globalComplexityResult.Complexity, - Depth: globalComplexityResult.Depth, - PerRootField: allFieldComplexityResults, - Errors: nil, - } - - if !report.HasErrors() { - return result, nil - } - - result.Errors = graphqlerrors.RequestErrorsFromOperationReport(report) - - var err error - if len(report.InternalErrors) > 0 { - err = report.InternalErrors[0] - } - - return result, err -} diff --git a/execution/graphql/request.go b/execution/graphql/request.go index 1e10748819..a3ab0888d0 100644 --- a/execution/graphql/request.go +++ b/execution/graphql/request.go @@ -9,7 +9,6 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) @@ -67,23 +66,6 @@ func (r *Request) SetHeader(header http.Header) { r.request.Header = header } -func (r *Request) CalculateComplexity(complexityCalculator ComplexityCalculator, schema *Schema) (ComplexityResult, error) { - if schema == nil { - return ComplexityResult{}, ErrNilSchema - } - - report := r.parseQueryOnce() - if report.HasErrors() { - return complexityResult( - operation_complexity.OperationStats{}, - []operation_complexity.RootFieldStats{}, - report, - ) - } - - return complexityCalculator.Calculate(&r.document, &schema.document) -} - func (r *Request) Document() *ast.Document { return &r.document } diff --git a/execution/graphql/request_test.go b/execution/graphql/request_test.go index a39160dfbf..d59f1d72a5 100644 --- a/execution/graphql/request_test.go +++ b/execution/graphql/request_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity" "github.com/wundergraph/graphql-go-tools/v2/pkg/starwars" ) @@ -88,64 +89,69 @@ func TestRequest_parseQueryOnce(t *testing.T) { } func TestRequest_CalculateComplexity(t *testing.T) { - t.Run("should return error when schema is nil", func(t *testing.T) { - request := Request{} - result, err := request.CalculateComplexity(DefaultComplexityCalculator, nil) - assert.Error(t, err) - assert.Equal(t, ErrNilSchema, err) - assert.Equal(t, 0, result.NodeCount, "unexpected node count") - assert.Equal(t, 0, result.Complexity, "unexpected complexity") - assert.Equal(t, 0, result.Depth, "unexpected depth") - assert.Nil(t, result.PerRootField, "per root field results is not nil") - }) - t.Run("should successfully calculate the complexity of request", func(t *testing.T) { schema := StarwarsSchema(t) - request := StarwarsRequestForQuery(t, starwars.FileSimpleHeroQuery) - result, err := request.CalculateComplexity(DefaultComplexityCalculator, schema) - assert.NoError(t, err) - assert.Equal(t, 1, result.NodeCount, "unexpected node count") - assert.Equal(t, 1, result.Complexity, "unexpected complexity") - assert.Equal(t, 2, result.Depth, "unexpected depth") - assert.Equal(t, []FieldComplexityResult{ + + report := request.parseQueryOnce() + assert.False(t, report.HasErrors()) + + estimator := operation_complexity.NewOperationComplexityEstimator(false) + global, rootFields := estimator.Do(request.Document(), schema.Document(), &report) + assert.False(t, report.HasErrors()) + + assert.Equal(t, 1, global.NodeCount, "unexpected node count") + assert.Equal(t, 1, global.Complexity, "unexpected complexity") + assert.Equal(t, 2, global.Depth, "unexpected depth") + assert.Equal(t, []operation_complexity.RootFieldStats{ { - TypeName: "Query", - FieldName: "hero", - Alias: "", - NodeCount: 1, - Complexity: 1, - Depth: 1, + TypeName: "Query", + FieldName: "hero", + Alias: "", + Stats: operation_complexity.OperationStats{ + NodeCount: 1, + Complexity: 1, + Depth: 1, + }, }, - }, result.PerRootField, "unexpected per root field results") + }, rootFields, "unexpected per root field results") }) t.Run("should successfully calculate the complexity of request with multiple query fields", func(t *testing.T) { schema := StarwarsSchema(t) - request := StarwarsRequestForQuery(t, starwars.FileHeroWithAliasesQuery) - result, err := request.CalculateComplexity(DefaultComplexityCalculator, schema) - assert.NoError(t, err) - assert.Equal(t, 2, result.NodeCount, "unexpected node count") - assert.Equal(t, 2, result.Complexity, "unexpected complexity") - assert.Equal(t, 2, result.Depth, "unexpected depth") - assert.Equal(t, []FieldComplexityResult{ + + report := request.parseQueryOnce() + assert.False(t, report.HasErrors()) + + estimator := operation_complexity.NewOperationComplexityEstimator(false) + global, rootFields := estimator.Do(request.Document(), schema.Document(), &report) + assert.False(t, report.HasErrors()) + + assert.Equal(t, 2, global.NodeCount, "unexpected node count") + assert.Equal(t, 2, global.Complexity, "unexpected complexity") + assert.Equal(t, 2, global.Depth, "unexpected depth") + assert.Equal(t, []operation_complexity.RootFieldStats{ { - TypeName: "Query", - FieldName: "hero", - Alias: "empireHero", - NodeCount: 1, - Complexity: 1, - Depth: 1, + TypeName: "Query", + FieldName: "hero", + Alias: "empireHero", + Stats: operation_complexity.OperationStats{ + NodeCount: 1, + Complexity: 1, + Depth: 1, + }, }, { - TypeName: "Query", - FieldName: "hero", - Alias: "jediHero", - NodeCount: 1, - Complexity: 1, - Depth: 1, - }}, result.PerRootField, "unexpected per root field results") + TypeName: "Query", + FieldName: "hero", + Alias: "jediHero", + Stats: operation_complexity.OperationStats{ + NodeCount: 1, + Complexity: 1, + Depth: 1, + }, + }}, rootFields, "unexpected per root field results") }) } diff --git a/v2/pkg/middleware/operation_complexity/operation_complexity.go b/v2/pkg/middleware/operation_complexity/operation_complexity.go index dba06c7fc8..6357edbfdd 100644 --- a/v2/pkg/middleware/operation_complexity/operation_complexity.go +++ b/v2/pkg/middleware/operation_complexity/operation_complexity.go @@ -52,9 +52,8 @@ var ( ) const ( - skipIntrospection = true - __schemaLiteral = "__schema" - __typeLiteral = "__type" + __schemaLiteral = "__schema" + __typeLiteral = "__type" ) type OperationComplexityEstimator struct { @@ -62,12 +61,12 @@ type OperationComplexityEstimator struct { visitor *complexityVisitor } -func NewOperationComplexityEstimator() *OperationComplexityEstimator { - +func NewOperationComplexityEstimator(skipIntrospection bool) *OperationComplexityEstimator { walker := astvisitor.NewWalker(48) visitor := &complexityVisitor{ - Walker: &walker, - multipliers: make([]multiplier, 0, 16), + Walker: &walker, + multipliers: make([]multiplier, 0, 16), + skipIntrospection: skipIntrospection, } walker.RegisterEnterDocumentVisitor(visitor) @@ -116,8 +115,9 @@ func (n *OperationComplexityEstimator) Do(operation, definition *ast.Document, r return globalResult, n.visitor.calculatedRootFieldStats } +// Deprecated: use NewOperationComplexityEstimator. func CalculateOperationComplexity(operation, definition *ast.Document, report *operationreport.Report) (OperationStats, []RootFieldStats) { - estimator := NewOperationComplexityEstimator() + estimator := NewOperationComplexityEstimator(false) return estimator.Do(operation, definition, report) } @@ -141,6 +141,9 @@ type complexityVisitor struct { currentRootFieldSelectionSetDepth int calculatedRootFieldStats []RootFieldStats + + // Enforces to ignore introspection queries in calculations. + skipIntrospection bool } type multiplier struct { @@ -202,7 +205,7 @@ func (c *complexityVisitor) EnterField(ref int) { } typeName, fieldName, alias := c.extractFieldRelatedNames(ref, definition) - if skipIntrospection && (fieldName == __schemaLiteral || fieldName == __typeLiteral) { + if c.skipIntrospection && (fieldName == __schemaLiteral || fieldName == __typeLiteral) { c.SkipNode() return } diff --git a/v2/pkg/middleware/operation_complexity/operation_complexity_test.go b/v2/pkg/middleware/operation_complexity/operation_complexity_test.go index 14d316aa14..9850655594 100644 --- a/v2/pkg/middleware/operation_complexity/operation_complexity_test.go +++ b/v2/pkg/middleware/operation_complexity/operation_complexity_test.go @@ -467,6 +467,26 @@ func TestCalculateOperationComplexity(t *testing.T) { }) t.Run("introspection query", func(t *testing.T) { run(t, testDefinition, introspectionQuery, + OperationStats{ + NodeCount: 59, + Complexity: 59, + Depth: 13, + }, + []RootFieldStats{ + { + TypeName: "Query", + FieldName: "__schema", + Stats: OperationStats{ + NodeCount: 59, + Complexity: 59, + Depth: 12, + }, + }, + }, + ) + }) + t.Run("introspection query with skip", func(t *testing.T) { + runSkipIntrospection(t, testDefinition, introspectionQuery, OperationStats{ NodeCount: 0, Complexity: 0, @@ -477,17 +497,16 @@ func TestCalculateOperationComplexity(t *testing.T) { }) } -var run = func(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) { +func runConfig(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats, skipIntrospection bool) { def := unsafeparser.ParseGraphqlDocumentString(definition) op := unsafeparser.ParseGraphqlDocumentString(operation) report := operationreport.Report{} astnormalization.NormalizeOperation(&op, &def, &report) - actualGlobalComplexityResult, actualFieldsComplexityResult := CalculateOperationComplexity(&op, &def, &report) - if report.HasErrors() { - require.NoError(t, report) - } + estimator := NewOperationComplexityEstimator(skipIntrospection) + actualGlobalComplexityResult, actualFieldsComplexityResult := estimator.Do(&op, &def, &report) + require.False(t, report.HasErrors()) assert.Equal(t, expectedGlobalComplexityResult.NodeCount, actualGlobalComplexityResult.NodeCount, "unexpected global node count") assert.Equal(t, expectedGlobalComplexityResult.Complexity, actualGlobalComplexityResult.Complexity, "unexpected global complexity") @@ -495,17 +514,26 @@ var run = func(t *testing.T, definition, operation string, expectedGlobalComplex assert.Equal(t, expectedFieldsComplexityResult, actualFieldsComplexityResult, "unexpected fields complexity result") } +func run(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) { + runConfig(t, definition, operation, expectedGlobalComplexityResult, expectedFieldsComplexityResult, false) +} + +func runSkipIntrospection(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) { + runConfig(t, definition, operation, expectedGlobalComplexityResult, expectedFieldsComplexityResult, true) +} + func BenchmarkEstimateComplexity(b *testing.B) { def := unsafeparser.ParseGraphqlDocumentString(testDefinition) op := unsafeparser.ParseGraphqlDocumentString(complexQuery) - estimator := NewOperationComplexityEstimator() - report := operationreport.Report{} - b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { + // We use NewOperationComplexityEstimator for every operation in production, thus + // we want it in the benchmarking loop. + estimator := NewOperationComplexityEstimator(false) + report := operationreport.Report{} globalComplexityResult, _ := estimator.Do(&op, &def, &report) if report.HasErrors() { b.Fatal(report)