diff --git a/v2/pkg/astnormalization/astnormalization.go b/v2/pkg/astnormalization/astnormalization.go index f9ba253999..60ff1f2d03 100644 --- a/v2/pkg/astnormalization/astnormalization.go +++ b/v2/pkg/astnormalization/astnormalization.go @@ -208,6 +208,7 @@ func (o *OperationNormalizer) setupOperationWalkers() { } directivesIncludeSkip := astvisitor.NewWalker(8) + preventFragmentCycles(&directivesIncludeSkip) directiveIncludeSkip(&directivesIncludeSkip) cleanup := astvisitor.NewWalker(8) @@ -392,3 +393,99 @@ func (v *VariablesNormalizer) NormalizeOperation(operation, definition *ast.Docu return v.variablesExtractionVisitor.uploadsPath } + +type fragmentCycleVisitor struct { + *astvisitor.Walker + operation, definition *ast.Document + currentFragmentRef int // current fragment ref + spreadsInFragments map[int][]int // fragment ref -> spread refs +} + +func (f *fragmentCycleVisitor) LeaveDocument(operation, _ *ast.Document) { + report := f.Walker.Report + if report == nil { + return + } + + visited := make(map[string]bool) + stack := make(map[string]bool) + + for fragmentIdx := range f.spreadsInFragments { + f.detectFragmentCycle(fragmentIdx, []int{fragmentIdx}, visited, stack, operation) + } +} + +func (f *fragmentCycleVisitor) detectFragmentCycle(fragmentIdx int, path []int, visited, stack map[string]bool, operation *ast.Document) bool { + fragName := string(operation.FragmentDefinitionNameBytes(fragmentIdx)) + if stack[fragName] { + // Cycle detected, report using the spread that closes the cycle + cycleStart := 0 + for i, idx := range path { + if string(operation.FragmentDefinitionNameBytes(idx)) == fragName { + cycleStart = i + break + } + } + cyclePath := path[cycleStart:] + if len(cyclePath) > 0 { + // The spread that closes the cycle is the first spread in the cycle + cycleFragIdx := cyclePath[0] + spreadName := operation.FragmentDefinitionNameBytes(cycleFragIdx) + f.Walker.Report.AddExternalError(operationreport.ErrFragmentSpreadFormsCycle(spreadName)) + } + return true + } + if visited[fragName] { + return false + } + visited[fragName] = true + stack[fragName] = true + for _, spreadRef := range f.spreadsInFragments[fragmentIdx] { + // Find the fragment definition index for this spread name + fragName := operation.FragmentSpreadNameBytes(spreadRef) + fragRef, exists := operation.FragmentDefinitionRef(fragName) + if exists && f.detectFragmentCycle(fragRef, append(path, fragRef), visited, stack, operation) { + return true + } + } + stack[fragName] = false + return false +} + +func (f *fragmentCycleVisitor) EnterDocument(operation, definition *ast.Document) { + f.operation = operation + f.definition = definition + f.currentFragmentRef = -1 + f.spreadsInFragments = make(map[int][]int) +} + +func (f *fragmentCycleVisitor) LeaveFragmentDefinition(ref int) { + f.currentFragmentRef = -1 +} + +func (f *fragmentCycleVisitor) EnterFragmentDefinition(ref int) { + f.currentFragmentRef = ref +} + +func (f *fragmentCycleVisitor) EnterFragmentSpread(ref int) { + if f.currentFragmentRef == -1 { + return + } + if _, exists := f.spreadsInFragments[f.currentFragmentRef]; !exists { + f.spreadsInFragments[f.currentFragmentRef] = []int{ref} + return + } + f.spreadsInFragments[f.currentFragmentRef] = append(f.spreadsInFragments[f.currentFragmentRef], ref) +} + +func preventFragmentCycles(walker *astvisitor.Walker) *fragmentCycleVisitor { + visitor := &fragmentCycleVisitor{ + Walker: walker, + operation: nil, + definition: nil, + } + walker.RegisterDocumentVisitor(visitor) + walker.RegisterEnterFragmentSpreadVisitor(visitor) + walker.RegisterFragmentDefinitionVisitor(visitor) + return visitor +} diff --git a/v2/pkg/astparser/parser_test.go b/v2/pkg/astparser/parser_test.go index bc25ef587d..2c6d87235e 100644 --- a/v2/pkg/astparser/parser_test.go +++ b/v2/pkg/astparser/parser_test.go @@ -2368,8 +2368,7 @@ this is a schema \ panic("want schema description to be defined") } description := doc.Input.ByteSliceString(schema.Description.Content) - expectedDescription := `this is a schema \` - require.Equal(t, expectedDescription, description) + require.Equal(t, `this is a schema \ `, description) query := doc.RootOperationTypeDefinitions[schema.RootOperationTypeDefinitions.Refs[0]] if query.OperationType != ast.OperationTypeQuery { panic("want OperationTypeQuery") @@ -2436,8 +2435,7 @@ this is a schema \ panic("want schema description to be defined") } description := doc.Input.ByteSliceString(schema.Description.Content) - expectedDescription := `this is a schema \` - require.Equal(t, expectedDescription, description) + require.Equal(t, `this is a schema \ `, description) query := doc.RootOperationTypeDefinitions[schema.RootOperationTypeDefinitions.Refs[0]] if query.OperationType != ast.OperationTypeQuery { panic("want OperationTypeQuery") @@ -2504,9 +2502,7 @@ this is a schema \ if name.DefaultValue.Value.Kind != ast.ValueKindString { panic("want ValueKindString") } - if doc.Input.ByteSliceString(doc.StringValues[name.DefaultValue.Value.Ref].Content) != `Gopher \` { - panic("want Gopher") - } + assert.Equal(t, doc.Input.ByteSliceString(doc.StringValues[name.DefaultValue.Value.Ref].Content), `Gopher \ `) }) }) @@ -2530,9 +2526,7 @@ this is a schema \ if name.DefaultValue.Value.Kind != ast.ValueKindString { panic("want ValueKindString") } - if doc.Input.ByteSliceString(doc.StringValues[name.DefaultValue.Value.Ref].Content) != `Gopher \\` { - panic("want Gopher") - } + assert.Equal(t, doc.Input.ByteSliceString(doc.StringValues[name.DefaultValue.Value.Ref].Content), `Gopher \\ `) }) }) }) @@ -2580,6 +2574,18 @@ func TestErrorReport(t *testing.T) { t.Fatalf("want:\n%s\ngot:\n%s\n", want, report.Error()) } }) + t.Run("ident incomplete block string", func(t *testing.T) { + _, report := ParseGraphqlDocumentString(`union"""`) + + if !report.HasErrors() { + t.Fatalf("want err, got nil") + } + + want := "external: unexpected token - got: BLOCKSTRING want one of: [IDENT], locations: [{Line:1 Column:6}], path: []" + if report.Error() != want { + t.Fatalf("want:\n%s\ngot:\n%s\n", want, report.Error()) + } + }) } func TestParseStarwars(t *testing.T) { diff --git a/v2/pkg/astparser/testdata/starwars.schema.graphql b/v2/pkg/astparser/testdata/starwars.schema.graphql index c6c665e03a..14b699d476 100644 --- a/v2/pkg/astparser/testdata/starwars.schema.graphql +++ b/v2/pkg/astparser/testdata/starwars.schema.graphql @@ -172,7 +172,7 @@ scalar Boolean scalar ID "Directs the executor to include this field or fragment only when the argument is true." directive @include( - " Included when true." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT "Directs the executor to skip this field or fragment when the argument is true." diff --git a/v2/pkg/astparser/testdata/todo.graphql b/v2/pkg/astparser/testdata/todo.graphql index 5174a32e8c..0ae6800cd2 100644 --- a/v2/pkg/astparser/testdata/todo.graphql +++ b/v2/pkg/astparser/testdata/todo.graphql @@ -56,7 +56,7 @@ scalar Boolean scalar ID "Directs the executor to include this field or fragment only when the argument is true." directive @include( - " Included when true." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT "Directs the executor to skip this field or fragment when the argument is true." diff --git a/v2/pkg/astprinter/fixtures/starwars_schema_definition.golden b/v2/pkg/astprinter/fixtures/starwars_schema_definition.golden index 25ae066630..9d75e2dbc3 100644 --- a/v2/pkg/astprinter/fixtures/starwars_schema_definition.golden +++ b/v2/pkg/astprinter/fixtures/starwars_schema_definition.golden @@ -210,7 +210,7 @@ scalar ID "Directs the executor to include this field or fragment only when the argument is true." directive @include( - "Included whentrue." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT diff --git a/v2/pkg/astprinter/testdata/starwars.schema.graphql b/v2/pkg/astprinter/testdata/starwars.schema.graphql index 2661f079a8..172c67293c 100644 --- a/v2/pkg/astprinter/testdata/starwars.schema.graphql +++ b/v2/pkg/astprinter/testdata/starwars.schema.graphql @@ -204,7 +204,7 @@ scalar Boolean scalar ID "Directs the executor to include this field or fragment only when the argument is true." directive @include( - " Included whentrue." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT "Directs the executor to skip this field or fragment when the argument is true." diff --git a/v2/pkg/asttransform/baseschema.go b/v2/pkg/asttransform/baseschema.go index 0567e58bbf..45a6ebc94f 100644 --- a/v2/pkg/asttransform/baseschema.go +++ b/v2/pkg/asttransform/baseschema.go @@ -148,7 +148,7 @@ scalar Boolean scalar ID "Directs the executor to include this field or fragment only when the argument is true." directive @include( - " Included when true." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT "Directs the executor to skip this field or fragment when the argument is true." diff --git a/v2/pkg/astvalidation/operation_validation_test.go b/v2/pkg/astvalidation/operation_validation_test.go index ff0f4c796b..76f12a1957 100644 --- a/v2/pkg/astvalidation/operation_validation_test.go +++ b/v2/pkg/astvalidation/operation_validation_test.go @@ -5075,7 +5075,7 @@ scalar Boolean scalar ID @custom(typeName: "string") "Directs the executor to include this field or fragment only when the argument is true." directive @include( - " Included when true." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT "Directs the executor to skip this field or fragment when the argument is true." diff --git a/v2/pkg/introspection/testdata/starwars.schema.graphql b/v2/pkg/introspection/testdata/starwars.schema.graphql index cc7d84112e..e777756cad 100644 --- a/v2/pkg/introspection/testdata/starwars.schema.graphql +++ b/v2/pkg/introspection/testdata/starwars.schema.graphql @@ -178,7 +178,7 @@ scalar Boolean scalar ID "Directs the executor to include this field or fragment only when the argument is true." directive @include( - " Included when true." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT "Directs the executor to skip this field or fragment when the argument is true." diff --git a/v2/pkg/lexer/lexer.go b/v2/pkg/lexer/lexer.go index ca282256f3..ab37df3747 100644 --- a/v2/pkg/lexer/lexer.go +++ b/v2/pkg/lexer/lexer.go @@ -352,6 +352,9 @@ func (l *Lexer) readBlockString(tok *token.Token) { quoteCount = 0 whitespaceCount++ case runes.EOF: + tok.SetEnd(l.input.InputPosition, l.input.TextPosition) + tok.Literal.Start += uint32(leadingWhitespaceToken) + tok.Literal.End -= uint32(whitespaceCount) return case runes.QUOTE: if escaped { @@ -385,27 +388,20 @@ func (l *Lexer) readBlockString(tok *token.Token) { } func (l *Lexer) readSingleLineString(tok *token.Token) { - tok.Keyword = keyword.STRING tok.SetStart(l.input.InputPosition, l.input.TextPosition) tok.TextPosition.CharStart -= 1 escaped := false - whitespaceCount := 0 - reachedFirstNonWhitespace := false - leadingWhitespaceToken := 0 for { next := l.readRune() switch next { case runes.SPACE, runes.TAB: escaped = false - whitespaceCount++ case runes.EOF: tok.SetEnd(l.input.InputPosition, l.input.TextPosition) - tok.Literal.Start += uint32(leadingWhitespaceToken) - tok.Literal.End -= uint32(whitespaceCount) return case runes.QUOTE, runes.CARRIAGERETURN, runes.LINETERMINATOR: if escaped { @@ -414,19 +410,11 @@ func (l *Lexer) readSingleLineString(tok *token.Token) { } tok.SetEnd(l.input.InputPosition-1, l.input.TextPosition) - tok.Literal.Start += uint32(leadingWhitespaceToken) - tok.Literal.End -= uint32(whitespaceCount) return case runes.BACKSLASH: escaped = !escaped - whitespaceCount = 0 default: - if !reachedFirstNonWhitespace { - reachedFirstNonWhitespace = true - leadingWhitespaceToken = whitespaceCount - } escaped = false - whitespaceCount = 0 } } } diff --git a/v2/pkg/lexer/lexer_test.go b/v2/pkg/lexer/lexer_test.go index 427db2b8d2..1fc0e532b1 100644 --- a/v2/pkg/lexer/lexer_test.go +++ b/v2/pkg/lexer/lexer_test.go @@ -2,7 +2,6 @@ package lexer import ( "encoding/json" - "fmt" "os" "testing" @@ -15,6 +14,7 @@ import ( ) func TestLexer_Peek_Read(t *testing.T) { + t.Parallel() type checkFunc func(lex *Lexer, i int) @@ -34,11 +34,11 @@ func TestLexer_Peek_Read(t *testing.T) { return func(lex *Lexer, i int) { tok := lex.Read() if k != tok.Keyword { - panic(fmt.Errorf("mustRead: want(keyword): %s, got: %s [check: %d]", k.String(), tok.String(), i)) + t.Errorf("mustRead: want(keyword): %q, got: %q [check: %d]", k.String(), tok.String(), i) } gotLiteral := string(lex.input.ByteSlice(tok.Literal)) if wantLiteral != gotLiteral { - panic(fmt.Errorf("mustRead: want(literal): %s, got: %s [check: %d]", wantLiteral, gotLiteral, i)) + t.Errorf("mustRead: want(literal): %q, got: %q [check: %d]", wantLiteral, gotLiteral, i) } } } @@ -54,16 +54,16 @@ func TestLexer_Peek_Read(t *testing.T) { tok := lex.Read() if lineStart != tok.TextPosition.LineStart { - panic(fmt.Errorf("mustReadPosition: want(lineStart): %d, got: %d [check: %d]", lineStart, tok.TextPosition.LineStart, i)) + t.Errorf("mustReadPosition: want(lineStart): %d, got: %d [check: %d]", lineStart, tok.TextPosition.LineStart, i) } if charStart != tok.TextPosition.CharStart { - panic(fmt.Errorf("mustReadPosition: want(charStart): %d, got: %d [check: %d]", charStart, tok.TextPosition.CharStart, i)) + t.Errorf("mustReadPosition: want(charStart): %d, got: %d [check: %d]", charStart, tok.TextPosition.CharStart, i) } if lineEnd != tok.TextPosition.LineEnd { - panic(fmt.Errorf("mustReadPosition: want(lineEnd): %d, got: %d [check: %d]", lineEnd, tok.TextPosition.LineEnd, i)) + t.Errorf("mustReadPosition: want(lineEnd): %d, got: %d [check: %d]", lineEnd, tok.TextPosition.LineEnd, i) } if charEnd != tok.TextPosition.CharEnd { - panic(fmt.Errorf("mustReadPosition: want(charEnd): %d, got: %d [check: %d]", charEnd, tok.TextPosition.CharEnd, i)) + t.Errorf("mustReadPosition: want(charEnd): %d, got: %d [check: %d]", charEnd, tok.TextPosition.CharEnd, i) } } } @@ -72,7 +72,7 @@ func TestLexer_Peek_Read(t *testing.T) { return func(lex *Lexer, i int) { got := lex.peekWhitespaceLength() if want != got { - panic(fmt.Errorf("mustPeekWhitespaceLength: want: %d, got: %d [check: %d]", want, got, i)) + t.Errorf("mustPeekWhitespaceLength: want: %d, got: %d [check: %d]", want, got, i) } } } @@ -174,54 +174,80 @@ func TestLexer_Peek_Read(t *testing.T) { run("-1.758E11", mustRead(keyword.SUB, "-"), mustRead(keyword.FLOAT, "1.758E11")) }) - t.Run("read single line string", func(t *testing.T) { - run("\"foo\"", mustRead(keyword.STRING, "foo")) + t.Run("read string", func(t *testing.T) { + run(`"foo"`, mustRead(keyword.STRING, `foo`)) }) - t.Run("read single line string with leading/trailing whitespace", func(t *testing.T) { - run("\" foo \"", mustRead(keyword.STRING, "foo")) + t.Run("read string with leading/trailing whitespace", func(t *testing.T) { + run("\" \tfoo\t \"", mustRead(keyword.STRING, " \tfoo\t ")) }) t.Run("peek incomplete string as quote", func(t *testing.T) { - run("\"foo", mustRead(keyword.STRING, "foo")) + run(`"foo`, mustRead(keyword.STRING, "foo")) }) - t.Run("read single line string with escaped quote", func(t *testing.T) { - run("\"foo \\\" bar\"", mustRead(keyword.STRING, "foo \\\" bar")) + t.Run("read string with escaped quote", func(t *testing.T) { + run(`"foo \" bar"`, mustRead(keyword.STRING, `foo \" bar`)) }) - t.Run("read single line string with escaped backslash", func(t *testing.T) { - run("\"foo \\\\ bar\"", mustRead(keyword.STRING, "foo \\\\ bar")) + t.Run("read string with escaped backslash", func(t *testing.T) { + run(`"foo \\ bar"`, mustRead(keyword.STRING, `foo \\ bar`)) }) - t.Run("read multi line string with escaped quote", func(t *testing.T) { - run("\"\"\"foo \\\" bar\"\"\"", mustRead(keyword.BLOCKSTRING, "foo \\\" bar")) + t.Run("read block string with escaped quote", func(t *testing.T) { + run(`"""foo \" bar"""`, mustRead(keyword.BLOCKSTRING, `foo \" bar`)) }) - t.Run("read multi line string with two escaped quotes", func(t *testing.T) { - run("\"\"\"foo \"\" bar\"\"\"", mustRead(keyword.BLOCKSTRING, "foo \"\" bar")) + t.Run("read block string with two escaped quotes", func(t *testing.T) { + run(`"""foo "" bar"""`, mustRead(keyword.BLOCKSTRING, `foo "" bar`)) }) - t.Run("read multi line string", func(t *testing.T) { + t.Run("read block string padded with whitespaces", func(t *testing.T) { + run(`""" foo bar """`, mustRead(keyword.BLOCKSTRING, `foo bar`)) + }) + t.Run("read block string", func(t *testing.T) { run("\"\"\"\nfoo\nbar\"\"\"", mustRead(keyword.BLOCKSTRING, "foo\nbar")) }) - t.Run("read multi line string with carriage return", func(t *testing.T) { + t.Run("read block string with carriage return", func(t *testing.T) { run("\"\"\"\r\nfoo\r\nbar\"\"\"", mustRead(keyword.BLOCKSTRING, "foo\r\nbar")) }) - t.Run("read multi line string with escaped backslash", func(t *testing.T) { - run("\"\"\"foo \\\\ bar\"\"\"", mustRead(keyword.BLOCKSTRING, "foo \\\\ bar")) + t.Run("read block string with escaped backslash", func(t *testing.T) { + run(`"""foo \\ bar"""`, mustRead(keyword.BLOCKSTRING, `foo \\ bar`)) }) - t.Run("read multi line string with leading/trailing space", func(t *testing.T) { + t.Run("read block string with leading/trailing space", func(t *testing.T) { run(`""" foo """`, mustRead(keyword.BLOCKSTRING, "foo")) }) - t.Run("read multi line string with trailing leading/trailing tab", func(t *testing.T) { + t.Run("read block string incomplete trailing space", func(t *testing.T) { + run(`"""foo unfinished `, mustRead(keyword.BLOCKSTRING, "foo unfinished")) + }) + t.Run("read ident and block string incomplete empty", func(t *testing.T) { + run(`union"""`, + mustRead(keyword.IDENT, "union"), + mustRead(keyword.BLOCKSTRING, ""), + ) + }) + t.Run("read ident and block string incomplete", func(t *testing.T) { + run(`union"""incomplete str`, + mustRead(keyword.IDENT, "union"), + mustRead(keyword.BLOCKSTRING, "incomplete str"), + ) + }) + t.Run("read block string with surrounding tabs", func(t *testing.T) { run(`""" foo """`, mustRead(keyword.BLOCKSTRING, "foo")) }) - t.Run("read multi line string with trailing leading/trailing LT", func(t *testing.T) { + t.Run("read block string with leading/trailing newlines", func(t *testing.T) { run(`""" - foo + foo """`, mustRead(keyword.BLOCKSTRING, "foo")) }) - t.Run("complex multi line string", func(t *testing.T) { + t.Run("read block string with common indent", func(t *testing.T) { + run(`""" + indented + lines + a + b +"""`, mustRead(keyword.BLOCKSTRING, "indented\n\tlines\n\t\ta\n\t\tb")) + }) + t.Run("complex block string", func(t *testing.T) { run("\"\"\"block string uses \\\"\"\"\n\"\"\"", mustRead(keyword.BLOCKSTRING, "block string uses \\\"\"\"")) }) - t.Run("complex multi line string with carriage return", func(t *testing.T) { + t.Run("complex block string with carriage return", func(t *testing.T) { run("\"\"\"block string uses \\\"\"\"\r\n\"\"\"", mustRead(keyword.BLOCKSTRING, "block string uses \\\"\"\"")) }) - t.Run("read multi line string with trailing leading/trailing whitespace combination", func(t *testing.T) { + t.Run("read block string with leading/trailing whitespace combination", func(t *testing.T) { run(` """ foo """`, mustRead(keyword.BLOCKSTRING, "foo")) diff --git a/v2/pkg/middleware/operation_complexity/operation_complexity_test.go b/v2/pkg/middleware/operation_complexity/operation_complexity_test.go index 169386e066..256c2c72c0 100644 --- a/v2/pkg/middleware/operation_complexity/operation_complexity_test.go +++ b/v2/pkg/middleware/operation_complexity/operation_complexity_test.go @@ -622,7 +622,7 @@ scalar Boolean scalar ID @custom(typeName: "string") "Directs the executor to include this field or fragment only when the argument is true." directive @include( - " Included when true." + "Included when true." if: Boolean! ) on FIELD | FRAGMENT_SPREAD | INLINE_FRAGMENT "Directs the executor to skip this field or fragment when the argument is true."