diff --git a/pkg/graphql/input_validation.go b/pkg/graphql/input_validation.go new file mode 100644 index 0000000000..f9c8d8f1b0 --- /dev/null +++ b/pkg/graphql/input_validation.go @@ -0,0 +1,44 @@ +package graphql + +import ( + "github.com/TykTechnologies/graphql-go-tools/pkg/operationreport" + "github.com/TykTechnologies/graphql-go-tools/pkg/variablevalidator" +) + +type InputValidationResult struct { + Valid bool + Errors Errors +} + +func inputValidationResultFromReport(report operationreport.Report) (InputValidationResult, error) { + result := InputValidationResult{ + Valid: false, + Errors: nil, + } + + if !report.HasErrors() { + result.Valid = true + return result, nil + } + + result.Errors = RequestErrorsFromOperationReport(report) + + var err error + if len(report.InternalErrors) > 0 { + err = report.InternalErrors[0] + } + + return result, err +} + +func (r *Request) ValidateInput(schema *Schema) (InputValidationResult, error) { + validator := variablevalidator.NewVariableValidator() + + report := r.parseQueryOnce() + if report.HasErrors() { + return inputValidationResultFromReport(report) + } + validator.Validate(&r.document, &schema.document, []byte(r.OperationName), r.Variables, &report) + + return inputValidationResultFromReport(report) +} diff --git a/pkg/graphql/input_validation_test.go b/pkg/graphql/input_validation_test.go new file mode 100644 index 0000000000..2bfcb62787 --- /dev/null +++ b/pkg/graphql/input_validation_test.go @@ -0,0 +1,30 @@ +package graphql + +import ( + "github.com/TykTechnologies/graphql-go-tools/pkg/starwars" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestRequest_ValidateInput(t *testing.T) { + t.Run("Should pass input validation", func(t *testing.T) { + schema := starwarsSchema(t) + request := requestForQuery(t, starwars.FileDroidWithArgAndVarQuery) + request.Variables = []byte(`{"droidID":"testID"}`) + + result, err := request.ValidateInput(schema) + assert.NoError(t, err) + assert.True(t, result.Valid) + assert.Nil(t, result.Errors) + }) + + t.Run("Should fail input validation", func(t *testing.T) { + schema := starwarsSchema(t) + request := requestForQuery(t, starwars.FileDroidWithArgAndVarQuery) + + result, err := request.ValidateInput(schema) + assert.NoError(t, err) + assert.False(t, result.Valid) + assert.Equal(t, `Required variable "$droidID" was not provided, locations: [{Line:1 Column:13}], path: [query]`, result.Errors.Error()) + }) +} diff --git a/pkg/operationreport/externalerror.go b/pkg/operationreport/externalerror.go index 5f8e871172..d43c267b8e 100644 --- a/pkg/operationreport/externalerror.go +++ b/pkg/operationreport/externalerror.go @@ -27,6 +27,8 @@ const ( UnknownFieldOfInputObjectErrMsg = `Field "%s" is not defined by type "%s".` DuplicatedFieldInputObjectErrMsg = `There can be only one input field named "%s".` ValueIsNotAnInputObjectTypeErrMsg = `Expected value of type "%s", found %s.` + VariableNotProvidedErrMsg = `Required variable "$%s" was not provided` + VariableValidationFailedErrMsg = `Validation for variable "%s" failed: %s` ) type ExternalError struct { @@ -242,6 +244,19 @@ func ErrValueIsNotAnInputObjectType(value, inputType ast.ByteSlice, position pos return err } +func ErrVariableNotProvided(name ast.ByteSlice, position position.Position) (err ExternalError) { + err.Message = fmt.Sprintf(VariableNotProvidedErrMsg, name) + err.Locations = LocationsFromPosition(position) + + return err +} + +func ErrVariableValidationFailed(name ast.ByteSlice, message string, position position.Position) (err ExternalError) { + err.Message = fmt.Sprintf(VariableValidationFailedErrMsg, name, message) + err.Locations = LocationsFromPosition(position) + return err +} + func ErrValueDoesntSatisfyString(value, inputType ast.ByteSlice, position position.Position) (err ExternalError) { err.Message = fmt.Sprintf(NotStringErrMsg, inputType, value) err.Locations = LocationsFromPosition(position) diff --git a/pkg/variablevalidator/variablevalidator.go b/pkg/variablevalidator/variablevalidator.go new file mode 100644 index 0000000000..c760a91703 --- /dev/null +++ b/pkg/variablevalidator/variablevalidator.go @@ -0,0 +1,107 @@ +package variablevalidator + +import ( + "bytes" + "context" + "errors" + "fmt" + "github.com/TykTechnologies/graphql-go-tools/pkg/ast" + "github.com/TykTechnologies/graphql-go-tools/pkg/astvisitor" + "github.com/TykTechnologies/graphql-go-tools/pkg/graphqljsonschema" + "github.com/TykTechnologies/graphql-go-tools/pkg/operationreport" + "github.com/buger/jsonparser" +) + +type VariableValidator struct { + walker *astvisitor.Walker + visitor *validatorVisitor +} + +func NewVariableValidator() *VariableValidator { + walker := astvisitor.Walker{} + validator := VariableValidator{ + walker: &walker, + visitor: &validatorVisitor{ + Walker: &walker, + currentOperation: ast.InvalidRef, + }, + } + + validator.walker.RegisterEnterDocumentVisitor(validator.visitor) + validator.walker.RegisterEnterOperationVisitor(validator.visitor) + validator.walker.RegisterLeaveOperationVisitor(validator.visitor) + validator.walker.RegisterEnterVariableDefinitionVisitor(validator.visitor) + + return &validator +} + +type validatorVisitor struct { + *astvisitor.Walker + + operationName, variables []byte + currentOperation int + operation, definition *ast.Document +} + +func (v *validatorVisitor) EnterDocument(operation, definition *ast.Document) { + v.operation, v.definition = operation, definition +} + +func (v *validatorVisitor) EnterVariableDefinition(ref int) { + if v.currentOperation == ast.InvalidRef { + return + } + typeRef := v.operation.VariableDefinitions[ref].Type + + variableName := v.operation.VariableDefinitionNameBytes(ref) + variable, t, _, err := jsonparser.Get(v.variables, string(variableName)) + if t == jsonparser.NotExist && v.operation.TypeIsNonNull(typeRef) { + v.StopWithExternalErr(operationreport.ErrVariableNotProvided(variableName, v.operation.VariableDefinitions[ref].VariableValue.Position)) + return + } + if err != nil { + v.StopWithInternalErr(errors.New("error parsing variables")) + return + } + + if t == jsonparser.String { + variable = []byte(fmt.Sprintf(`"%s"`, string(variable))) + } + + jsonSchema := graphqljsonschema.FromTypeRef(v.operation, v.definition, typeRef) + schemaValidator, err := graphqljsonschema.NewValidatorFromSchema(jsonSchema) + if err != nil { + v.StopWithInternalErr(err) + return + } + if err := schemaValidator.Validate(context.Background(), variable); err != nil { + v.StopWithExternalErr(operationreport.ErrVariableValidationFailed(variableName, err.Error(), v.operation.VariableDefinitions[ref].VariableValue.Position)) + return + } +} + +func (v *validatorVisitor) EnterOperationDefinition(ref int) { + if len(v.operationName) == 0 { + v.currentOperation = ref + return + } + + if bytes.Equal(v.operationName, v.operation.OperationDefinitionNameBytes(ref)) { + v.currentOperation = ref + } +} + +func (v *validatorVisitor) LeaveOperationDefinition(ref int) { + if v.currentOperation == ref { + v.Stop() + } +} + +func (v *VariableValidator) Validate(operation, definition *ast.Document, operationName, variables []byte, report *operationreport.Report) { + if v.visitor != nil { + v.visitor.operationName = operationName + v.visitor.variables = variables + } + + v.walker.Walk(operation, definition, report) +} diff --git a/pkg/variablevalidator/variablevalidator_test.go b/pkg/variablevalidator/variablevalidator_test.go new file mode 100644 index 0000000000..dc17fd0dcb --- /dev/null +++ b/pkg/variablevalidator/variablevalidator_test.go @@ -0,0 +1,123 @@ +package variablevalidator + +import ( + "github.com/TykTechnologies/graphql-go-tools/internal/pkg/unsafeparser" + "github.com/TykTechnologies/graphql-go-tools/pkg/asttransform" + "github.com/TykTechnologies/graphql-go-tools/pkg/operationreport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +const testDefinition = ` +input CustomInput { + requiredField: String! + optionalField: String +} + +type Query{ + simpleQuery(code: ID): String + inputOfInt(code: Int!): String +} + +type Mutation { + customInputNonNull(in: CustomInput!): String +}` + +const ( + testQuery = ` +query testQuery($code: ID!){ + simpleQuery(code: $code) +} +` + testQueryInt = ` +query testQuery($code: Int!){ + inputOfInt(code: $code) +} +` + + customInputMutation = ` +mutation testMutation($in: CustomInput!){ + customInputNonNull(in: $in) +}` + + customMultipleOperation = ` +query testQuery($code: ID!){ + simpleQuery(code: $code) +} + +mutation testMutation($in: CustomInput!){ + customInputNonNull(in: $in) +} +` +) + +func TestVariableValidator(t *testing.T) { + testCases := []struct { + name string + operation string + operationName string + variables string + expectedError string + }{ + { + name: "basic variable query", + operation: testQuery, + variables: `{"code":"NG"}`, + }, + { + name: "basic variable query of int", + operation: testQueryInt, + variables: `{"code":1}`, + }, + { + name: "missing variable", + operation: testQuery, + variables: `{"codes":"NG"}`, + expectedError: `Required variable "$code" was not provided`, + }, + { + name: "no variable passed", + operation: testQuery, + variables: "", + expectedError: `Required variable "$code" was not provided`, + }, + { + name: "nested input variable", + operation: customInputMutation, + variables: `{"in":{"optionalField":"test"}}`, + expectedError: `Validation for variable "in" failed: validation failed: /: {"optionalField":"te... "requiredField" value is required`, + }, + { + name: "multiple operation should validate first operation", + operation: customMultipleOperation, + variables: `{"code":"NG"}`, + }, + { + name: "multiple operation should validate operation name", + operation: customMultipleOperation, + operationName: "testMutation", + variables: `{"in":{"requiredField":"test"}}`, + }, + } + for _, c := range testCases { + t.Run(c.name, func(t *testing.T) { + definitionDocument := unsafeparser.ParseGraphqlDocumentString(testDefinition) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&definitionDocument)) + + operationDocument := unsafeparser.ParseGraphqlDocumentString(c.operation) + + report := operationreport.Report{} + validator := NewVariableValidator() + validator.Validate(&operationDocument, &definitionDocument, []byte(c.operationName), []byte(c.variables), &report) + + if c.expectedError == "" && report.HasErrors() { + t.Fatalf("expected no error, instead got %s", report.Error()) + } + if c.expectedError != "" { + require.Equal(t, 1, len(report.ExternalErrors)) + assert.Equal(t, c.expectedError, report.ExternalErrors[0].Message) + } + }) + } +}