From 0cba6b3f136012bff4d830f47dcffdf149237d24 Mon Sep 17 00:00:00 2001 From: Dave Freilich Date: Fri, 8 Nov 2024 10:14:40 +0200 Subject: [PATCH] fix(router): ensure L2 cache uses client name in key --- .../automatic_persisted_queries_test.go | 34 +++++------------- .../persisted_operations_over_get_test.go | 6 ++-- router-tests/persisted_operations_test.go | 21 ++++++++--- router/core/errors.go | 4 +-- router/core/graphql_prehandler.go | 2 +- router/core/operation_processor.go | 35 ++++++++++--------- router/core/websocket.go | 2 +- 7 files changed, 49 insertions(+), 55 deletions(-) diff --git a/router-tests/automatic_persisted_queries_test.go b/router-tests/automatic_persisted_queries_test.go index de4f019517..a896208be7 100644 --- a/router-tests/automatic_persisted_queries_test.go +++ b/router-tests/automatic_persisted_queries_test.go @@ -27,11 +27,9 @@ func TestAutomaticPersistedQueries(t *testing.T) { Enabled: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "does-not-exist"}}`), }) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res.Body) }) }) @@ -47,12 +45,10 @@ func TestAutomaticPersistedQueries(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { header := make(http.Header) header.Add("graphql-client-name", "my-client") - res0, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res0 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38"}}`), Header: header, }) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res0.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res0.Body) res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -92,12 +88,10 @@ func TestAutomaticPersistedQueries(t *testing.T) { time.Sleep(3 * time.Second) - res0, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res0 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38"}}`), Header: header, }) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res0.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res0.Body) }) }) @@ -174,11 +168,9 @@ func TestAutomaticPersistedQueries(t *testing.T) { }, }, }, func(t *testing.T, xEnv *testenv.Environment) { - res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "does-not-exist"}}`), }) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res.Body) }) }) @@ -210,12 +202,10 @@ func TestAutomaticPersistedQueries(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { header := make(http.Header) header.Add("graphql-client-name", "my-client") - res0, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res0 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38"}}`), Header: header, }) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res0.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res0.Body) res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -273,12 +263,10 @@ func TestAutomaticPersistedQueries(t *testing.T) { time.Sleep(3 * time.Second) - res0, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res0 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38"}}`), Header: header, }) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res0.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res0.Body) }) }) @@ -358,14 +346,11 @@ func BenchmarkAutomaticPersistedQueriesCacheEnabled(b *testing.B) { }, func(b *testing.B, xEnv *testenv.Environment) { header := make(http.Header) header.Add("graphql-client-name", "my-client") - res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `{ employees { details { forename location { ...CountryFields } maritalStatus middlename nationality pastLocations { country { ...CountryFields } name type } pets { class gender name ...AlligatorFields ...CatFields ...DogFields ...MouseFields ...PonyFields } surname } } } fragment CountryFields on Country { key { name } } fragment AlligatorFields on Alligator { __typename class dangerous gender name } fragment CatFields on Cat { __typename class gender name type } fragment DogFields on Dog { __typename breed class gender name } fragment MouseFields on Mouse { __typename class gender name } fragment PonyFields on Pony { __typename class gender name }`, Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "fb51f4141cc4f185fedc9956ae9e047b193edb196c6c095af8be785011a7c2ff"}}`), Header: header, }) - if err != nil { - b.Fatal(err) - } if res.Body != expected { b.Fatalf("unexpected response: %s", res.Body) } @@ -373,14 +358,11 @@ func BenchmarkAutomaticPersistedQueriesCacheEnabled(b *testing.B) { for pb.Next() { header := make(http.Header) header.Add("graphql-client-name", "my-client") - res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `{ employees { details { forename location { ...CountryFields } maritalStatus middlename nationality pastLocations { country { ...CountryFields } name type } pets { class gender name ...AlligatorFields ...CatFields ...DogFields ...MouseFields ...PonyFields } surname } } } fragment CountryFields on Country { key { name } } fragment AlligatorFields on Alligator { __typename class dangerous gender name } fragment CatFields on Cat { __typename class gender name type } fragment DogFields on Dog { __typename breed class gender name } fragment MouseFields on Mouse { __typename class gender name } fragment PonyFields on Pony { __typename class gender name }`, Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "fb51f4141cc4f185fedc9956ae9e047b193edb196c6c095af8be785011a7c2ff"}}`), Header: header, }) - if err != nil { - b.Fatal(err) - } if res.Body != expected { b.Fatalf("unexpected response: %s", res.Body) } diff --git a/router-tests/persisted_operations_over_get_test.go b/router-tests/persisted_operations_over_get_test.go index 6ed1e83fcf..94dd5dcc67 100644 --- a/router-tests/persisted_operations_over_get_test.go +++ b/router-tests/persisted_operations_over_get_test.go @@ -24,7 +24,7 @@ func TestPersistedOperationOverGET(t *testing.T) { Header: header, }) require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res.Response.StatusCode) + require.Equal(t, http.StatusOK, res.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res.Body) }) }) @@ -92,7 +92,7 @@ func TestAutomatedPersistedQueriesOverGET(t *testing.T) { Header: header, }) require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res.Response.StatusCode) + require.Equal(t, http.StatusOK, res.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res.Body) }) }) @@ -114,7 +114,7 @@ func TestAutomatedPersistedQueriesOverGET(t *testing.T) { Header: header, }) require.NoError(t, err0) - require.Equal(t, http.StatusBadRequest, res0.Response.StatusCode) + require.Equal(t, http.StatusOK, res0.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res0.Body) res1, err1 := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{ diff --git a/router-tests/persisted_operations_test.go b/router-tests/persisted_operations_test.go index 5226280a2c..533c2ad648 100644 --- a/router-tests/persisted_operations_test.go +++ b/router-tests/persisted_operations_test.go @@ -18,11 +18,9 @@ func TestPersistedOperationNotFound(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { - res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "does-not-exist"}}`), }) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, res.Response.StatusCode) require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res.Body) }) } @@ -137,6 +135,19 @@ func TestPersistedOperationsCache(t *testing.T) { require.Equal(t, `{"data":{"employees":[{"details":{"pets":null}},{"details":{"pets":null}},{"details":{"pets":[{"name":"Snappy","__typename":"Alligator"}]}},{"details":{"pets":[{"name":"Abby","__typename":"Dog","breed":"GOLDEN_RETRIEVER","class":"MAMMAL","gender":"FEMALE"},{"name":"Survivor","__typename":"Pony"}]}},{"details":{"pets":[{"name":"Blotch","__typename":"Cat","class":"MAMMAL","gender":"FEMALE","type":"STREET"},{"name":"Grayone","__typename":"Cat","class":"MAMMAL","gender":"MALE","type":"STREET"},{"name":"Rusty","__typename":"Cat","class":"MAMMAL","gender":"MALE","type":"STREET"},{"name":"Manya","__typename":"Cat","class":"MAMMAL","gender":"FEMALE","type":"HOME"},{"name":"Peach","__typename":"Cat","class":"MAMMAL","gender":"MALE","type":"STREET"},{"name":"Panda","__typename":"Cat","class":"MAMMAL","gender":"MALE","type":"HOME"},{"name":"Mommy","__typename":"Cat","class":"MAMMAL","gender":"FEMALE","type":"STREET"},{"name":"Terry","__typename":"Cat","class":"MAMMAL","gender":"FEMALE","type":"HOME"},{"name":"Tilda","__typename":"Cat","class":"MAMMAL","gender":"FEMALE","type":"HOME"},{"name":"Vasya","__typename":"Cat","class":"MAMMAL","gender":"MALE","type":"HOME"}]}},{"details":{"pets":null}},{"details":{"pets":null}},{"details":{"pets":[{"name":"Vanson","__typename":"Mouse"}]}},{"details":{"pets":null}},{"details":{"pets":[{"name":"Pepper","__typename":"Cat","class":"MAMMAL","gender":"FEMALE","type":"HOME"}]}}]}}`, res.Body) require.Equal(t, "HIT", res.Response.Header.Get(core.PersistedOperationCacheHeader)) require.Equal(t, "HIT", res.Response.Header.Get(core.ExecutionPlanCacheHeader)) + + header = make(http.Header) + header.Add("graphql-client-name", "not-my-client") + res, err = xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + OperationName: []byte(`"Employees"`), + Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "2267510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`), + Header: header, + Variables: []byte(`{"withAligators": false,"withCats": true,"skipDogs": false,"skipMouses": true}`), + }) + require.NoError(t, err) + require.Equal(t, `{"errors":[{"message":"persisted query not found","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]}`, res.Body) + require.Equal(t, "", res.Response.Header.Get(core.PersistedOperationCacheHeader)) + require.Equal(t, "", res.Response.Header.Get(core.ExecutionPlanCacheHeader)) } retrieveNumberOfCDNRequests := func(t *testing.T, cdnURL string) int { @@ -156,7 +167,7 @@ func TestPersistedOperationsCache(t *testing.T) { testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { sendTwoRequests(t, xEnv) numberOfCDNRequests := retrieveNumberOfCDNRequests(t, xEnv.CDN.URL) - require.Equal(t, 1, numberOfCDNRequests) + require.Equal(t, 2, numberOfCDNRequests) }) }) @@ -170,7 +181,7 @@ func TestPersistedOperationsCache(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { sendTwoRequests(t, xEnv) numberOfCDNRequests := retrieveNumberOfCDNRequests(t, xEnv.CDN.URL) - require.Equal(t, 2, numberOfCDNRequests) + require.Equal(t, 3, numberOfCDNRequests) }) }) } diff --git a/router/core/errors.go b/router/core/errors.go index 8e33307303..1b3e2330f5 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -281,8 +281,8 @@ func writeOperationError(r *http.Request, w http.ResponseWriter, requestLogger * case errors.As(err, &httpErr): writeRequestErrors(r, w, httpErr.StatusCode(), requestErrorsFromHttpError(httpErr), requestLogger) case errors.As(err, &poNotFoundErr): - newErr := NewHttpGraphqlError("persisted query not found", "PERSISTED_QUERY_NOT_FOUND", http.StatusBadRequest) - writeRequestErrors(r, w, http.StatusBadRequest, requestErrorsFromHttpError(newErr), requestLogger) + newErr := NewHttpGraphqlError("persisted query not found", "PERSISTED_QUERY_NOT_FOUND", http.StatusOK) + writeRequestErrors(r, w, http.StatusOK, requestErrorsFromHttpError(newErr), requestLogger) case errors.As(err, &reportErr): report := reportErr.Report() logInternalErrorsFromReport(reportErr.Report(), requestLogger) diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index ae6b6430fa..49d11618e7 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -576,7 +576,7 @@ func (h *PreHandler) handleOperation(req *http.Request, buf *bytes.Buffer, varia trace.WithAttributes(requestContext.telemetry.traceAttrs...), ) - cached, err := operationKit.NormalizeOperation(isApq) + cached, err := operationKit.NormalizeOperation(requestContext.operation.clientInfo.Name, isApq) if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) diff --git a/router/core/operation_processor.go b/router/core/operation_processor.go index ce15c5e74a..385f1ccac4 100644 --- a/router/core/operation_processor.go +++ b/router/core/operation_processor.go @@ -360,7 +360,7 @@ func (o *OperationKit) FetchPersistedOperation(ctx context.Context, clientInfo * statusCode: http.StatusOK, } } - fromCache, err := o.loadPersistedOperationFromCache() + fromCache, err := o.loadPersistedOperationFromCache(clientInfo.Name) if err != nil { return false, false, &httpGraphqlError{ statusCode: http.StatusInternalServerError, @@ -368,7 +368,7 @@ func (o *OperationKit) FetchPersistedOperation(ctx context.Context, clientInfo * } } if fromCache { - if isApq, _ := o.persistedOperationCacheKeyHasTtl(); isApq { + if isApq, _ := o.persistedOperationCacheKeyHasTtl(clientInfo.Name); isApq { // if it is an APQ request, we need to save it again to renew the TTL expiration if err = o.operationProcessor.persistedOperationClient.SaveOperation(ctx, clientInfo.Name, o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash, o.parsedOperation.NormalizedRepresentation); err != nil { return false, false, err @@ -589,18 +589,18 @@ func (o *OperationKit) Parse() error { // NormalizeOperation normalizes the operation. After normalization the normalized representation of the operation // and variables is available. Also, the final operation ID is generated. -func (o *OperationKit) NormalizeOperation(isApq bool) (bool, error) { +func (o *OperationKit) NormalizeOperation(clientName string, isApq bool) (bool, error) { if o.parsedOperation.IsPersistedOperation { - return o.normalizePersistedOperation(isApq) + return o.normalizePersistedOperation(clientName, isApq) } return o.normalizeNonPersistedOperation() } -func (o *OperationKit) normalizePersistedOperation(isApq bool) (cached bool, err error) { +func (o *OperationKit) normalizePersistedOperation(clientName string, isApq bool) (cached bool, err error) { if o.parsedOperation.NormalizedRepresentation != "" { // when dealing with APQ requests which have a TTL set, we need to renew the TTL - if shouldRenew, skipIncludeNames := o.persistedOperationCacheKeyHasTtl(); shouldRenew { - o.savePersistedOperationToCache(true, skipIncludeNames) + if shouldRenew, skipIncludeNames := o.persistedOperationCacheKeyHasTtl(clientName); shouldRenew { + o.savePersistedOperationToCache(clientName, true, skipIncludeNames) } // normalized operation was loaded from cache return true, nil @@ -638,7 +638,7 @@ func (o *OperationKit) normalizePersistedOperation(isApq bool) (cached bool, err o.parsedOperation.Request.Variables = o.kit.doc.Input.Variables if o.cache != nil && o.cache.persistedOperationNormalizationCache != nil { - o.savePersistedOperationToCache(isApq, skipIncludeNames) + o.savePersistedOperationToCache(clientName, isApq, skipIncludeNames) } return false, nil @@ -750,13 +750,13 @@ func (o *OperationKit) NormalizeVariables() error { return nil } -func (o *OperationKit) loadPersistedOperationFromCache() (ok bool, err error) { +func (o *OperationKit) loadPersistedOperationFromCache(clientName string) (ok bool, err error) { if o.cache == nil || o.cache.persistedOperationNormalizationCache == nil { return false, nil } - cacheKey, ok := o.loadPersistedOperationCacheKey(o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash) + cacheKey, ok := o.loadPersistedOperationCacheKey(clientName, o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash) if !ok { return false, nil } @@ -790,7 +790,7 @@ func (o *OperationKit) jsonIsNull(variables []byte) bool { return value.Type() == fastjson.TypeNull } -func (o *OperationKit) persistedOperationCacheKeyHasTtl() (bool, []string) { +func (o *OperationKit) persistedOperationCacheKeyHasTtl(clientName string) (bool, []string) { if o.cache == nil || o.cache.persistedOperationVariableNames == nil || o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash == "" { return false, nil } @@ -801,14 +801,14 @@ func (o *OperationKit) persistedOperationCacheKeyHasTtl() (bool, []string) { if !present { return false, variableNames } - cacheKey := o.generatePersistedOperationCacheKey(variableNames) + cacheKey := o.generatePersistedOperationCacheKey(clientName, variableNames) ttl, ok := o.cache.persistedOperationNormalizationCache.GetTTL(cacheKey) return ok && ttl > 0, variableNames } -func (o *OperationKit) savePersistedOperationToCache(isApq bool, skipIncludeVariableNames []string) { - cacheKey := o.generatePersistedOperationCacheKey(skipIncludeVariableNames) +func (o *OperationKit) savePersistedOperationToCache(clientName string, isApq bool, skipIncludeVariableNames []string) { + cacheKey := o.generatePersistedOperationCacheKey(clientName, skipIncludeVariableNames) entry := NormalizationCacheEntry{ operationID: o.parsedOperation.ID, normalizedRepresentation: o.parsedOperation.NormalizedRepresentation, @@ -828,20 +828,21 @@ func (o *OperationKit) savePersistedOperationToCache(isApq bool, skipIncludeVari o.cache.persistedOperationVariableNamesLock.Unlock() } -func (o *OperationKit) loadPersistedOperationCacheKey(persistedQuerySha256Hash string) (key uint64, ok bool) { +func (o *OperationKit) loadPersistedOperationCacheKey(clientName, persistedQuerySha256Hash string) (key uint64, ok bool) { o.cache.persistedOperationVariableNamesLock.RLock() variableNames, ok := o.cache.persistedOperationVariableNames[persistedQuerySha256Hash] o.cache.persistedOperationVariableNamesLock.RUnlock() if !ok { return 0, false } - key = o.generatePersistedOperationCacheKey(variableNames) + key = o.generatePersistedOperationCacheKey(clientName, variableNames) return key, true } -func (o *OperationKit) generatePersistedOperationCacheKey(skipIncludeVariableNames []string) uint64 { +func (o *OperationKit) generatePersistedOperationCacheKey(clientName string, skipIncludeVariableNames []string) uint64 { _, _ = o.kit.keyGen.WriteString(o.parsedOperation.GraphQLRequestExtensions.PersistedQuery.Sha256Hash) _, _ = o.kit.keyGen.WriteString(o.parsedOperation.Request.OperationName) + _, _ = o.kit.keyGen.WriteString(clientName) o.writeSkipIncludeCacheKeyToKeyGen(skipIncludeVariableNames) sum := o.kit.keyGen.Sum64() o.kit.keyGen.Reset() diff --git a/router/core/websocket.go b/router/core/websocket.go index 989ba0fa07..b2d4d73c7f 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -803,7 +803,7 @@ func (h *WebSocketConnectionHandler) parseAndPlan(payload []byte) (*ParsedOperat startNormalization := time.Now() - if _, err := operationKit.NormalizeOperation(isApq); err != nil { + if _, err := operationKit.NormalizeOperation(h.clientInfo.Name, isApq); err != nil { opContext.normalizationTime = time.Since(startNormalization) return nil, nil, err }