Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 0 additions & 76 deletions execution/graphql/complexity.go

This file was deleted.

18 changes: 0 additions & 18 deletions execution/graphql/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
"github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity"
"github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport"
)

Expand Down Expand Up @@ -67,23 +66,6 @@ func (r *Request) SetHeader(header http.Header) {
r.request.Header = header
}

func (r *Request) CalculateComplexity(complexityCalculator ComplexityCalculator, schema *Schema) (ComplexityResult, error) {
if schema == nil {
return ComplexityResult{}, ErrNilSchema
}

report := r.parseQueryOnce()
if report.HasErrors() {
return complexityResult(
operation_complexity.OperationStats{},
[]operation_complexity.RootFieldStats{},
report,
)
}

return complexityCalculator.Calculate(&r.document, &schema.document)
}

func (r *Request) Document() *ast.Document {
return &r.document
}
Expand Down
96 changes: 51 additions & 45 deletions execution/graphql/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/stretchr/testify/assert"

"github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity"
"github.com/wundergraph/graphql-go-tools/v2/pkg/starwars"
)

Expand Down Expand Up @@ -88,64 +89,69 @@ func TestRequest_parseQueryOnce(t *testing.T) {
}

func TestRequest_CalculateComplexity(t *testing.T) {
t.Run("should return error when schema is nil", func(t *testing.T) {
request := Request{}
result, err := request.CalculateComplexity(DefaultComplexityCalculator, nil)
assert.Error(t, err)
assert.Equal(t, ErrNilSchema, err)
assert.Equal(t, 0, result.NodeCount, "unexpected node count")
assert.Equal(t, 0, result.Complexity, "unexpected complexity")
assert.Equal(t, 0, result.Depth, "unexpected depth")
assert.Nil(t, result.PerRootField, "per root field results is not nil")
})

t.Run("should successfully calculate the complexity of request", func(t *testing.T) {
schema := StarwarsSchema(t)

request := StarwarsRequestForQuery(t, starwars.FileSimpleHeroQuery)
result, err := request.CalculateComplexity(DefaultComplexityCalculator, schema)
assert.NoError(t, err)
assert.Equal(t, 1, result.NodeCount, "unexpected node count")
assert.Equal(t, 1, result.Complexity, "unexpected complexity")
assert.Equal(t, 2, result.Depth, "unexpected depth")
assert.Equal(t, []FieldComplexityResult{

report := request.parseQueryOnce()
assert.False(t, report.HasErrors())

estimator := operation_complexity.NewOperationComplexityEstimator(false)
global, rootFields := estimator.Do(request.Document(), schema.Document(), &report)
assert.False(t, report.HasErrors())

assert.Equal(t, 1, global.NodeCount, "unexpected node count")
assert.Equal(t, 1, global.Complexity, "unexpected complexity")
assert.Equal(t, 2, global.Depth, "unexpected depth")
assert.Equal(t, []operation_complexity.RootFieldStats{
{
TypeName: "Query",
FieldName: "hero",
Alias: "",
NodeCount: 1,
Complexity: 1,
Depth: 1,
TypeName: "Query",
FieldName: "hero",
Alias: "",
Stats: operation_complexity.OperationStats{
NodeCount: 1,
Complexity: 1,
Depth: 1,
},
},
}, result.PerRootField, "unexpected per root field results")
}, rootFields, "unexpected per root field results")
})

t.Run("should successfully calculate the complexity of request with multiple query fields", func(t *testing.T) {
schema := StarwarsSchema(t)

request := StarwarsRequestForQuery(t, starwars.FileHeroWithAliasesQuery)
result, err := request.CalculateComplexity(DefaultComplexityCalculator, schema)
assert.NoError(t, err)
assert.Equal(t, 2, result.NodeCount, "unexpected node count")
assert.Equal(t, 2, result.Complexity, "unexpected complexity")
assert.Equal(t, 2, result.Depth, "unexpected depth")
assert.Equal(t, []FieldComplexityResult{

report := request.parseQueryOnce()
assert.False(t, report.HasErrors())

estimator := operation_complexity.NewOperationComplexityEstimator(false)
global, rootFields := estimator.Do(request.Document(), schema.Document(), &report)
assert.False(t, report.HasErrors())

assert.Equal(t, 2, global.NodeCount, "unexpected node count")
assert.Equal(t, 2, global.Complexity, "unexpected complexity")
assert.Equal(t, 2, global.Depth, "unexpected depth")
assert.Equal(t, []operation_complexity.RootFieldStats{
{
TypeName: "Query",
FieldName: "hero",
Alias: "empireHero",
NodeCount: 1,
Complexity: 1,
Depth: 1,
TypeName: "Query",
FieldName: "hero",
Alias: "empireHero",
Stats: operation_complexity.OperationStats{
NodeCount: 1,
Complexity: 1,
Depth: 1,
},
},
{
TypeName: "Query",
FieldName: "hero",
Alias: "jediHero",
NodeCount: 1,
Complexity: 1,
Depth: 1,
}}, result.PerRootField, "unexpected per root field results")
TypeName: "Query",
FieldName: "hero",
Alias: "jediHero",
Stats: operation_complexity.OperationStats{
NodeCount: 1,
Complexity: 1,
Depth: 1,
},
}}, rootFields, "unexpected per root field results")
})
}

Expand Down
21 changes: 12 additions & 9 deletions v2/pkg/middleware/operation_complexity/operation_complexity.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,21 @@ var (
)

const (
skipIntrospection = true
__schemaLiteral = "__schema"
__typeLiteral = "__type"
__schemaLiteral = "__schema"
__typeLiteral = "__type"
)

type OperationComplexityEstimator struct {
walker *astvisitor.Walker
visitor *complexityVisitor
}

func NewOperationComplexityEstimator() *OperationComplexityEstimator {

func NewOperationComplexityEstimator(skipIntrospection bool) *OperationComplexityEstimator {
walker := astvisitor.NewWalker(48)
visitor := &complexityVisitor{
Walker: &walker,
multipliers: make([]multiplier, 0, 16),
Walker: &walker,
multipliers: make([]multiplier, 0, 16),
skipIntrospection: skipIntrospection,
}

walker.RegisterEnterDocumentVisitor(visitor)
Expand Down Expand Up @@ -116,8 +115,9 @@ func (n *OperationComplexityEstimator) Do(operation, definition *ast.Document, r
return globalResult, n.visitor.calculatedRootFieldStats
}

// Deprecated: use NewOperationComplexityEstimator.
func CalculateOperationComplexity(operation, definition *ast.Document, report *operationreport.Report) (OperationStats, []RootFieldStats) {
estimator := NewOperationComplexityEstimator()
estimator := NewOperationComplexityEstimator(false)
return estimator.Do(operation, definition, report)
}

Expand All @@ -141,6 +141,9 @@ type complexityVisitor struct {
currentRootFieldSelectionSetDepth int

calculatedRootFieldStats []RootFieldStats

// Enforces to ignore introspection queries in calculations.
skipIntrospection bool
}

type multiplier struct {
Expand Down Expand Up @@ -202,7 +205,7 @@ func (c *complexityVisitor) EnterField(ref int) {
}

typeName, fieldName, alias := c.extractFieldRelatedNames(ref, definition)
Comment thread
ysmolski marked this conversation as resolved.
if skipIntrospection && (fieldName == __schemaLiteral || fieldName == __typeLiteral) {
if c.skipIntrospection && (fieldName == __schemaLiteral || fieldName == __typeLiteral) {
c.SkipNode()
return
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,26 @@ func TestCalculateOperationComplexity(t *testing.T) {
})
t.Run("introspection query", func(t *testing.T) {
run(t, testDefinition, introspectionQuery,
OperationStats{
NodeCount: 59,
Complexity: 59,
Depth: 13,
},
[]RootFieldStats{
{
TypeName: "Query",
FieldName: "__schema",
Stats: OperationStats{
NodeCount: 59,
Complexity: 59,
Depth: 12,
},
},
},
)
})
t.Run("introspection query with skip", func(t *testing.T) {
runSkipIntrospection(t, testDefinition, introspectionQuery,
OperationStats{
NodeCount: 0,
Complexity: 0,
Expand All @@ -477,35 +497,43 @@ func TestCalculateOperationComplexity(t *testing.T) {
})
}

var run = func(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) {
func runConfig(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats, skipIntrospection bool) {
def := unsafeparser.ParseGraphqlDocumentString(definition)
op := unsafeparser.ParseGraphqlDocumentString(operation)
report := operationreport.Report{}

astnormalization.NormalizeOperation(&op, &def, &report)

actualGlobalComplexityResult, actualFieldsComplexityResult := CalculateOperationComplexity(&op, &def, &report)
if report.HasErrors() {
require.NoError(t, report)
}
estimator := NewOperationComplexityEstimator(skipIntrospection)
actualGlobalComplexityResult, actualFieldsComplexityResult := estimator.Do(&op, &def, &report)
require.False(t, report.HasErrors())

assert.Equal(t, expectedGlobalComplexityResult.NodeCount, actualGlobalComplexityResult.NodeCount, "unexpected global node count")
assert.Equal(t, expectedGlobalComplexityResult.Complexity, actualGlobalComplexityResult.Complexity, "unexpected global complexity")
assert.Equal(t, expectedGlobalComplexityResult.Depth, actualGlobalComplexityResult.Depth, "unexpected global depth")
assert.Equal(t, expectedFieldsComplexityResult, actualFieldsComplexityResult, "unexpected fields complexity result")
}

func run(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) {
runConfig(t, definition, operation, expectedGlobalComplexityResult, expectedFieldsComplexityResult, false)
}

func runSkipIntrospection(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) {
runConfig(t, definition, operation, expectedGlobalComplexityResult, expectedFieldsComplexityResult, true)
}

func BenchmarkEstimateComplexity(b *testing.B) {
def := unsafeparser.ParseGraphqlDocumentString(testDefinition)
op := unsafeparser.ParseGraphqlDocumentString(complexQuery)

estimator := NewOperationComplexityEstimator()
report := operationreport.Report{}

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
// We use NewOperationComplexityEstimator for every operation in production, thus
// we want it in the benchmarking loop.
estimator := NewOperationComplexityEstimator(false)
report := operationreport.Report{}
globalComplexityResult, _ := estimator.Do(&op, &def, &report)
if report.HasErrors() {
b.Fatal(report)
Expand Down