diff --git a/router-tests/http_subscriptions_test.go b/router-tests/http_subscriptions_test.go index 833334082b..7f88e6cfdf 100644 --- a/router-tests/http_subscriptions_test.go +++ b/router-tests/http_subscriptions_test.go @@ -136,9 +136,9 @@ func TestHeartbeats(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) - defer resp.Body.Close() reader := bufio.NewReader(resp.Body) lines := make(chan string, 50) @@ -198,4 +198,68 @@ func TestHeartbeats(t *testing.T) { }) }) }) + + t.Run("should write an error on sse", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval), + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"errors":[{"message":"Subgraph forbidden","extensions":{"code":"FORBIDDEN"}}]}`)) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + client := http.Client{ + Timeout: time.Second * 100, + } + + subscribePayload := []byte(`{"query":"subscription { countEmp(max: 5, intervalMilliseconds: 550) }"}`) + + req, err := http.NewRequest(http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(subscribePayload)) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Cache-Control", "no-cache") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + reader := bufio.NewReader(resp.Body) + lines := make(chan string, 50) + + go func() { + defer close(lines) + for { + line, _, err := reader.ReadLine() + if err != nil { + return + } + lines <- string(line) + } + }() + + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "event: next", line) + }) + + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, `data: {"errors":[{"message":"Subscription Upgrade request failed for Subgraph 'employees'.","extensions":{"statusCode":403}}],"data":null}`, line) + }) + + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "", line) + }) + }) + }) } diff --git a/router/core/flushwriter.go b/router/core/flushwriter.go index 90456ffc89..5e80248012 100644 --- a/router/core/flushwriter.go +++ b/router/core/flushwriter.go @@ -143,6 +143,12 @@ func (f *HttpFlushWriter) Flush() (err error) { separation = "" } + // resp sometimes ends with newlines. We need to remove them + // to cleanly add the seperation in the next step. + if bytes.HasSuffix(resp, []byte{'\n'}) { + resp = bytes.TrimRight(resp, "\n") + } + full := flushBreak + string(resp) + separation _, err = f.writer.Write([]byte(full)) if err != nil { diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index cad6df1466..f5475997de 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -430,8 +430,8 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv } } - if wsRw, ok := w.(*websocketResponseWriter); ok { - _ = wsRw.Flush() + if flusher, ok := w.(resolve.SubscriptionResponseWriter); ok { + _ = flusher.Flush() } }