diff --git a/router-tests/persisted_operations_over_get_test.go b/router-tests/persisted_operations_over_get_test.go index f1bcd4e1e5..836ff00beb 100644 --- a/router-tests/persisted_operations_over_get_test.go +++ b/router-tests/persisted_operations_over_get_test.go @@ -246,7 +246,7 @@ func TestPersistedSubscriptionOverGET(t *testing.T) { Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "a78014f326504cdcc3ed9c4440c989ca0ac7ef237f6379ea7fee0ffde5ea71cb"}}`), Header: map[string][]string{ "Content-Type": {"application/json"}, - "Accept": {"text/event-stream"}, + "Accept": {"text/event-stream,application/json"}, "Connection": {"keep-alive"}, "Cache-Control": {"no-cache"}, diff --git a/router/core/flushwriter.go b/router/core/flushwriter.go index b0a5d7f595..432bdb2d08 100644 --- a/router/core/flushwriter.go +++ b/router/core/flushwriter.go @@ -3,12 +3,13 @@ package core import ( "bytes" "context" + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "io" "mime" "net/http" - - "github.com/wundergraph/astjson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "strconv" + "strings" ) const ( @@ -183,14 +184,42 @@ func setSubscriptionHeaders(wgParams SubscriptionParams, r *http.Request, w http func NegotiateSubscriptionParams(r *http.Request) SubscriptionParams { q := r.URL.Query() - acceptHeader := r.Header.Get("Accept") + acceptHeaders := r.Header.Get("Accept") + elements := strings.Split(acceptHeaders, ",") + // Per RFC 9110, Accept header can be in the form`text/event-stream,application/json`, with an optional q-value to + // specify preference. We want to parse this and find the best option to use, and default to the first option if no + // q-value is provided. + // Eventually a solution will be in the stdlib: see https://github.com/golang/go/issues/19307, at which point we should + // remove this + var ( + useMultipart = false + useSse = q.Has(WgSseParam) + bestType = "" + bestQ = -1.0 // Default to lowest possible q-value + ) + + for _, acceptHeader := range elements { + mediaType, params, _ := mime.ParseMediaType(acceptHeader) + qValue := 1.0 // Default quality factor + if qStr, exists := params["q"]; exists { // If a quality factor exists, parse it and prefer it + if parsedQ, err := strconv.ParseFloat(qStr, 64); err == nil { + qValue = parsedQ + } + } - mediaType, _, _ := mime.ParseMediaType(acceptHeader) + // Find the media type with the highest q-value. If none is provided, it will default to the first option + // in the header, per https://www.rfc-editor.org/rfc/rfc9110.html#name-accept + if qValue > bestQ { + bestQ = qValue + bestType = mediaType + } + } subscribeOnce := q.Has(WgSubscribeOnceParam) - useMultipart := mediaType == multipartMime + useSse = useSse || bestType == sseMimeType + useMultipart = bestType == multipartMime return SubscriptionParams{ - UseSse: q.Has(WgSseParam) || mediaType == sseMimeType, + UseSse: useSse, SubscribeOnce: subscribeOnce, UseMultipart: useMultipart, } diff --git a/router/core/flushwriter_test.go b/router/core/flushwriter_test.go new file mode 100644 index 0000000000..d5a1b2cff0 --- /dev/null +++ b/router/core/flushwriter_test.go @@ -0,0 +1,123 @@ +package core + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "net/url" + "testing" +) + +func TestNegotiateSubscriptionParams(t *testing.T) { + type args struct { + r *http.Request + } + tests := []struct { + name string + args args + want SubscriptionParams + }{ + { + name: "No matching headers/subscribe once", + args: args{ + r: &http.Request{ + URL: &url.URL{RawQuery: "test"}, + Header: http.Header{ + "Accept": []string{"test,text/event-stream"}, + }}}, + want: SubscriptionParams{ + UseSse: false, + SubscribeOnce: false, + UseMultipart: false, + }, + }, + { + name: "Subscribe once", + args: args{ + r: &http.Request{ + URL: &url.URL{RawQuery: "wg_subscribe_once"}, + Header: http.Header{ + "Accept": []string{"text/event-stream,application/json"}, + }}}, + want: SubscriptionParams{ + UseSse: true, + SubscribeOnce: true, + UseMultipart: false, + }, + }, + { + name: "SSE with query", + args: args{ + r: &http.Request{ + URL: &url.URL{RawQuery: "wg_sse"}, + Header: http.Header{ + "Accept": []string{"application/json"}, + }}}, + want: SubscriptionParams{ + UseSse: true, + SubscribeOnce: false, + UseMultipart: false, + }, + }, + { + name: "SSE header", + args: args{ + r: &http.Request{ + URL: &url.URL{RawQuery: "test"}, + Header: http.Header{ + "Accept": []string{"text/event-stream,application/json"}, + }}}, + want: SubscriptionParams{ + UseSse: true, + SubscribeOnce: false, + UseMultipart: false, + }, + }, + { + name: "Multipart header", + args: args{ + r: &http.Request{ + URL: &url.URL{RawQuery: "test"}, + Header: http.Header{ + "Accept": []string{"multipart/mixed,application/json"}, + }}}, + want: SubscriptionParams{ + UseSse: false, + SubscribeOnce: false, + UseMultipart: true, + }, + }, + { + name: "Respect q preference (multipart wins)", + args: args{ + r: &http.Request{ + URL: &url.URL{RawQuery: "test"}, + Header: http.Header{ + "Accept": []string{"text/event-stream;q=0.9,application/json;q=0.8,multipart/mixed;q=1.0"}, + }}}, + want: SubscriptionParams{ + UseSse: false, + SubscribeOnce: false, + UseMultipart: true, + }, + }, + { + name: "Respect order (SSE wins)", + args: args{ + r: &http.Request{ + URL: &url.URL{RawQuery: "test"}, + Header: http.Header{ + "Accept": []string{"text/event-stream,application/json,multipart/mixed"}, + }}}, + want: SubscriptionParams{ + UseSse: true, + SubscribeOnce: false, + UseMultipart: false, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, NegotiateSubscriptionParams(tt.args.r), "NegotiateSubscriptionParams(%v)", tt.args.r) + }) + } +}