Skip to content
Merged
254 changes: 254 additions & 0 deletions router-tests/security_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package integration

import (
"fmt"
"github.com/wundergraph/cosmo/router/core"
"net/http"
"testing"

Expand Down Expand Up @@ -109,3 +111,255 @@ func TestParserHardLimits(t *testing.T) {
})
})
}

func TestQueryNamingLimits(t *testing.T) {
t.Parallel()

t.Run("verify operation query naming limits", func(t *testing.T) {
t.Parallel()

t.Run("with large query name and no operation name", func(t *testing.T) {
t.Parallel()
maxLength := 2
queryName := "longstring"

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = maxLength
},
}, func(t *testing.T, xEnv *testenv.Environment) {
expectedErrorMessage := fmt.Sprintf(`{"errors":[{"message":"operation name of length %d exceeds max length of %d"}]}`, len(queryName), maxLength)

resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resPost.Body)
require.Equal(t, http.StatusBadRequest, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resGet.Body)
require.Equal(t, http.StatusBadRequest, resGet.Response.StatusCode)
})
})

t.Run("with large query name and small operation name", func(t *testing.T) {
t.Parallel()
maxLength := 6
queryName := "longstring"
operationNameGet := `short`
operationNamePost := `"short"`

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = maxLength
},
}, func(t *testing.T, xEnv *testenv.Environment) {
expectedErrorMessage := fmt.Sprintf(`{"errors":[{"message":"operation name of length %d exceeds max length of %d"}]}`, len(queryName), maxLength)

resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
OperationName: []byte(operationNamePost),
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resPost.Body)
require.Equal(t, http.StatusBadRequest, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
OperationName: []byte(operationNameGet),
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resGet.Body)
require.Equal(t, http.StatusBadRequest, resGet.Response.StatusCode)
})
})

t.Run("with small query name and large operation name", func(t *testing.T) {
t.Parallel()

maxLength := 6
queryName := "short"
operationNameGet := `longname`
operationNamePost := `"longname"`

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = maxLength
},
}, func(t *testing.T, xEnv *testenv.Environment) {
expectedErrorMessage := fmt.Sprintf(`{"errors":[{"message":"operation name of length %d exceeds max length of %d"}]}`, len(operationNameGet), maxLength)

resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
OperationName: []byte(operationNamePost),
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resPost.Body)
require.Equal(t, http.StatusBadRequest, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
OperationName: []byte(operationNameGet),
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resGet.Body)
require.Equal(t, http.StatusBadRequest, resGet.Response.StatusCode)
})
})

t.Run("with small query name and small operation name", func(t *testing.T) {
t.Parallel()

liitSize := 7
queryName := "short"
operationNameGet := `short`
operationNamePost := `"short"`

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = liitSize
},
}, func(t *testing.T, xEnv *testenv.Environment) {
resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
OperationName: []byte(operationNamePost),
})
require.NoError(t, err)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, resPost.Body)
require.Equal(t, http.StatusOK, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + queryName + " { employees { id } }",
OperationName: []byte(operationNameGet),
})
require.NoError(t, err)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, resGet.Body)
require.Equal(t, http.StatusOK, resGet.Response.StatusCode)
})
})

t.Run("with multiple queries of which one is large", func(t *testing.T) {
t.Parallel()

maxLength := 6
query1Name := "short"
query2Name := "longstring"

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = maxLength
},
}, func(t *testing.T, xEnv *testenv.Environment) {
expectedErrorMessage := fmt.Sprintf(`{"errors":[{"message":"operation name of length %d exceeds max length of %d"}]}`, len(query2Name), maxLength)

resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + query1Name + " { employees { id } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resPost.Body)
require.Equal(t, http.StatusBadRequest, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + query1Name + " { employees { id } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resGet.Body)
require.Equal(t, http.StatusBadRequest, resGet.Response.StatusCode)
})
})

t.Run("with multiple queries of which both are small", func(t *testing.T) {
t.Parallel()

maxLength := 6
query1Name := "short1"
query2Name := "short2"

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = maxLength
},
}, func(t *testing.T, xEnv *testenv.Environment) {
resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + query1Name + " { employees { id } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, resPost.Body)
require.Equal(t, http.StatusOK, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + query1Name + " { employees { id } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, resGet.Body)
require.Equal(t, http.StatusOK, resGet.Response.StatusCode)
})
})

t.Run("with large queries with max length of 0 where the validation is not enabled", func(t *testing.T) {
t.Parallel()

maxLength := 0
query1Name := "longlonglonglonglonglonglonglonglonglong1"
query2Name := "short2"

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = maxLength
},
}, func(t *testing.T, xEnv *testenv.Environment) {
resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + query1Name + " { employees { id } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, resPost.Body)
require.Equal(t, http.StatusOK, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + query1Name + " { employees { id } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, resGet.Body)
require.Equal(t, http.StatusOK, resGet.Response.StatusCode)
})
})

// In case of introspection checks, we could potentially early return
t.Run("with multiple queries with introspection disabled", func(t *testing.T) {
t.Parallel()

maxLength := 6
query1Name := "longquery"
query2Name := "short2"

testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.OperationNameLengthLimit = maxLength
},
RouterOptions: []core.Option{
core.WithIntrospection(false),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
expectedErrorMessage := fmt.Sprintf(`{"errors":[{"message":"operation name of length %d exceeds max length of %d"}]}`, len(query1Name), maxLength)

resPost, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: "query " + query1Name + " { __schema { __typename } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resPost.Body)
require.Equal(t, http.StatusBadRequest, resPost.Response.StatusCode)

resGet, err := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
Query: "query " + query1Name + " { __schema { __typename } } query " + query2Name + " { employees { id } }",
})
require.NoError(t, err)
require.JSONEq(t, expectedErrorMessage, resGet.Body)
require.Equal(t, http.StatusBadRequest, resGet.Response.StatusCode)
})
})
})
}
1 change: 1 addition & 0 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,7 @@ func (s *graphServer) buildGraphMux(
MaxDepth: s.Config.securityConfiguration.ParserLimits.ApproximateDepthLimit,
MaxFields: s.Config.securityConfiguration.ParserLimits.TotalFieldsLimit,
},
OperationNameLengthLimit: s.securityConfiguration.OperationNameLengthLimit,
ApolloCompatibilityFlags: s.apolloCompatibilityFlags,
ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags,
DisableExposingVariablesContentOnValidationError: s.engineExecutionConfiguration.DisableExposingVariablesContentOnValidationError,
Expand Down
9 changes: 9 additions & 0 deletions router/core/graphql_prehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,15 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson
}
}

if operationKit.isOperationNameLengthLimitExceeded(operationKit.parsedOperation.Request.OperationName) {
return &httpGraphqlError{
message: fmt.Sprintf("operation name of length %d exceeds max length of %d",
len(operationKit.parsedOperation.Request.OperationName),
operationKit.operationProcessor.operationNameLengthLimit),
statusCode: http.StatusBadRequest,
}
}

// Compute the operation sha256 hash as soon as possible for observability reasons
if h.shouldComputeOperationSha256(operationKit) {
if err := operationKit.ComputeOperationSha256(); err != nil {
Expand Down
33 changes: 33 additions & 0 deletions router/core/operation_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ type OperationProcessorOptions struct {
DisableExposingVariablesContentOnValidationError bool
ComplexityLimits *config.ComplexityLimits
ParserTokenizerLimits astparser.TokenizerLimits
OperationNameLengthLimit int
}

// OperationProcessor provides shared resources to the parseKit and OperationKit.
Expand All @@ -131,6 +132,7 @@ type OperationProcessor struct {
parseKitOptions *parseKitOptions
complexityLimits *config.ComplexityLimits
parserTokenizerLimits astparser.TokenizerLimits
operationNameLengthLimit int
}

// parseKit is a helper struct to parse, normalize and validate operations
Expand Down Expand Up @@ -487,6 +489,14 @@ func (o *OperationKit) isIntrospectionQuery() (result bool, err error) {
ref := possibleOperationDefinitionRefs[i]
name := o.kit.doc.OperationDefinitionNameString(ref)

if o.isOperationNameLengthLimitExceeded(name) {
return false, &httpGraphqlError{
message: fmt.Sprintf("operation name of length %d exceeds max length of %d",
len(name), o.operationProcessor.operationNameLengthLimit),
statusCode: http.StatusBadRequest,
}
}

if o.parsedOperation.Request.OperationName == name {
operationDefinitionRef = ref
break
Expand Down Expand Up @@ -527,6 +537,13 @@ func (o *OperationKit) isIntrospectionQuery() (result bool, err error) {
return false, nil
}

func (o *OperationKit) isOperationNameLengthLimitExceeded(operationName string) bool {
if o.operationProcessor.operationNameLengthLimit == 0 {
return false
}
return len(operationName) > o.operationProcessor.operationNameLengthLimit
}

// Parse parses the operation, populates the document and set the operation type.
// UnmarshalOperationFromBody must be called before calling this method.
func (o *OperationKit) Parse() error {
Expand Down Expand Up @@ -560,6 +577,11 @@ func (o *OperationKit) Parse() error {
isIntrospection, err := o.isIntrospectionQuery()

if err != nil {
var httpGqlError *httpGraphqlError
if errors.As(err, &httpGqlError) {
return httpGqlError
}

return &httpGraphqlError{
message: "could not determine if operation was an introspection query",
statusCode: http.StatusOK,
Expand All @@ -582,13 +604,23 @@ func (o *OperationKit) Parse() error {
o.kit.numOperations++
ref := o.kit.doc.RootNodes[i].Ref
name := string(o.kit.doc.OperationDefinitionNameBytes(ref))

if len(name) == 0 {
anonymousOperationCount++
if anonymousOperationDefinitionRef == -1 {
anonymousOperationDefinitionRef = ref
}
continue
}

if o.isOperationNameLengthLimitExceeded(name) {
return &httpGraphqlError{
message: fmt.Sprintf("operation name of length %d exceeds max length of %d",
len(name), o.operationProcessor.operationNameLengthLimit),
statusCode: http.StatusBadRequest,
}
}

if o.parsedOperation.Request.OperationName == "" {
o.operationDefinitionRef = ref
o.originalOperationNameRef = o.kit.doc.OperationDefinitions[ref].Name
Expand Down Expand Up @@ -1256,6 +1288,7 @@ func NewOperationProcessor(opts OperationProcessorOptions) *OperationProcessor {
parseKitSemaphore: make(chan int, opts.ParseKitPoolSize),
introspectionEnabled: opts.IntrospectionEnabled,
parserTokenizerLimits: opts.ParserTokenizerLimits,
operationNameLengthLimit: opts.OperationNameLengthLimit,
complexityLimits: opts.ComplexityLimits,
parseKitOptions: &parseKitOptions{
apolloCompatibilityFlags: opts.ApolloCompatibilityFlags,
Expand Down
1 change: 1 addition & 0 deletions router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ type SecurityConfiguration struct {
ComplexityLimits *ComplexityLimits `yaml:"complexity_limits"`
DepthLimit *QueryDepthConfiguration `yaml:"depth_limit"`
ParserLimits ParserLimitsConfiguration `yaml:"parser_limits"`
OperationNameLengthLimit int `yaml:"operation_name_length_limit" envDefault:"512" env:"SECURITY_OPERATION_NAME_LENGTH_LIMIT"` // 0 is disabled
}

type ParserLimitsConfiguration struct {
Expand Down
6 changes: 6 additions & 0 deletions router/pkg/config/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,12 @@
}
}
},
"operation_name_length_limit": {
"type": "integer",
"description": "The maximum allowed length of the operation name, 0 allows any length.",
"default": "512",
"minimum": 0
},
"depth_limit": {
"type": "object",
"description": "DEPRECATED (move to complexity_limits.depth): The configuration for adding a max depth limit for query (how many nested levels you can have in a query).",
Expand Down
Loading
Loading