diff --git a/router-tests/subscriptions/websocket_test.go b/router-tests/subscriptions/websocket_test.go index 981d72fe2c..723b80fee3 100644 --- a/router-tests/subscriptions/websocket_test.go +++ b/router-tests/subscriptions/websocket_test.go @@ -2165,13 +2165,53 @@ func TestWebSockets(t *testing.T) { require.NoError(t, err) require.Equal(t, "error", res.Type) require.Equal(t, "1", res.ID) - require.JSONEq(t, `[{"message":"operation '9a41d21da2823195ad42c11d51e9ad3345824abdabf567b3615a235843a1fcc7' for client 'my-client' not found"}]`, + require.Equal(t, `[{"message":"PersistedQueryNotFound","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]`, string(res.Payload)) require.NoError(t, conn.Close()) }) }) + t.Run("unknown operation gets rejected when safelist is enabled with log unknown", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithPersistedOperationsConfig(config.PersistedOperationsConfig{ + LogUnknown: true, + Safelist: config.SafelistConfiguration{Enabled: true}, + }), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, []byte(`{"graphql-client-name": "my-client"}`)) + err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { employeeUpdated(employeeID: 1) { id } }"}`), + }) + require.NoError(t, err) + var res testenv.WebSocketMessage + err = testenv.WSReadJSON(t, conn, &res) + require.NoError(t, err) + require.Equal(t, "error", res.Type) + require.Equal(t, "1", res.ID) + require.Equal(t, `[{"message":"PersistedQueryNotFound","extensions":{"code":"PERSISTED_QUERY_NOT_FOUND"}}]`, + string(res.Payload)) + + logEntries := xEnv.Observer().FilterMessageSnippet("Unknown persisted operation found").All() + require.Len(t, logEntries, 1) + requestContext := logEntries[0].ContextMap() + require.Contains(t, requestContext["query"], "subscription { employeeUpdated(employeeID: 1) { id } }") + require.Equal(t, "9a41d21da2823195ad42c11d51e9ad3345824abdabf567b3615a235843a1fcc7", requestContext["sha256Hash"]) + + require.NoError(t, conn.Close()) + }) + }) + t.Run("known hash passes when safelist is enabled", func(t *testing.T) { t.Parallel() diff --git a/router/core/batch.go b/router/core/batch.go index cd8670aafc..f0f745e35a 100644 --- a/router/core/batch.go +++ b/router/core/batch.go @@ -26,11 +26,6 @@ type BatchedOperationId struct{} const defaultBufioReaderSize = 4096 -const ( - ExtensionCodeBatchSizeExceeded = "BATCH_LIMIT_EXCEEDED" - ExtensionCodeBatchSubscriptionsUnsupported = "BATCHING_SUBSCRIPTION_UNSUPPORTED" -) - type HandlerOpts struct { MaxEntriesPerBatch int MaxRoutines int @@ -117,7 +112,7 @@ func processBatchedRequest(w http.ResponseWriter, r *http.Request, handlerOpts H statusCode: http.StatusOK, } if !handlerOpts.OmitExtensions { - maxError.extensionCode = ExtensionCodeBatchSizeExceeded + maxError.extensionCode = ExtCodeErrBatchSizeExceeded } return maxError } diff --git a/router/core/errors.go b/router/core/errors.go index c287eb2ea9..fdf811286d 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -54,6 +54,13 @@ type ( } ) +const ( + ExtCodeErrPersistedQueryNotFound = "PERSISTED_QUERY_NOT_FOUND" + ExtCodeErrErrorRequestCanceled = "REQUEST_CANCELED" + ExtCodeErrBatchSizeExceeded = "BATCH_LIMIT_EXCEEDED" + ExtCodeErrBatchSubscriptionsUnsupported = "BATCHING_SUBSCRIPTION_UNSUPPORTED" +) + func getErrorType(err error) errorType { if errors.Is(err, ErrRateLimitExceeded) { return errorTypeRateLimit @@ -343,7 +350,7 @@ func writeOperationError(r *http.Request, w http.ResponseWriter, requestLogger * var poNotFoundErr *persistedoperation.PersistentOperationNotFoundError switch { case errors.Is(err, context.Canceled): - newErr := NewHttpGraphqlError("request canceled", "REQUEST_CANCELED", http.StatusOK) + newErr := NewHttpGraphqlError("request canceled", ExtCodeErrErrorRequestCanceled, http.StatusOK) writeRequestErrors(writeRequestErrorsParams{ request: r, writer: w, @@ -362,7 +369,7 @@ func writeOperationError(r *http.Request, w http.ResponseWriter, requestLogger * headerPropagation: propagation, }) case errors.As(err, &poNotFoundErr): - newErr := NewHttpGraphqlError("PersistedQueryNotFound", "PERSISTED_QUERY_NOT_FOUND", http.StatusOK) + newErr := NewHttpGraphqlError("PersistedQueryNotFound", ExtCodeErrPersistedQueryNotFound, http.StatusOK) writeRequestErrors(writeRequestErrorsParams{ request: r, writer: w, diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index 1363d637bc..b161384d32 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -765,7 +765,7 @@ func (h *PreHandler) handleOperation(req *http.Request, httpOperation *httpOpera statusCode: http.StatusBadRequest, } if !h.omitBatchExtensions { - unsupportedErr.extensionCode = ExtensionCodeBatchSubscriptionsUnsupported + unsupportedErr.extensionCode = ExtCodeErrBatchSubscriptionsUnsupported } return unsupportedErr } diff --git a/router/core/websocket.go b/router/core/websocket.go index 0adbd02461..c19aa08be9 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -806,10 +806,24 @@ func (h *WebSocketConnectionHandler) requestError(err error) error { } func (h *WebSocketConnectionHandler) writeErrorMessage(operationID string, err error) error { - gqlErrors := []graphqlError{ - {Message: err.Error()}, + var gqlErr graphqlError + + var poNotFoundErr *persistedoperation.PersistentOperationNotFoundError + switch { + case errors.As(err, &poNotFoundErr): + // We follow the same pattern of not mentioning the sha256hash + // in the normal http requests for the same case + gqlErr = graphqlError{ + Message: "PersistedQueryNotFound", + Extensions: &Extensions{ + Code: ExtCodeErrPersistedQueryNotFound, + }, + } + default: + gqlErr = graphqlError{Message: err.Error()} } - payload, err := json.Marshal(gqlErrors) + + payload, err := json.Marshal([]graphqlError{gqlErr}) if err != nil { return fmt.Errorf("encoding GraphQL errors: %w", err) }