diff --git a/router-tests/events/kafka_events_test.go b/router-tests/events/kafka_events_test.go index dbc17f870a..3ad51a592c 100644 --- a/router-tests/events/kafka_events_test.go +++ b/router-tests/events/kafka_events_test.go @@ -417,7 +417,7 @@ func TestKafkaEvents(t *testing.T) { t.Run("multipart", func(t *testing.T) { t.Parallel() - multipartHeartbeatInterval := time.Second * 5 + subscriptionHeartbeatInterval := time.Second * 5 t.Run("subscribe sync", func(t *testing.T) { t.Parallel() @@ -428,7 +428,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, RouterOptions: []core.Option{ - core.WithMultipartHeartbeatInterval(multipartHeartbeatInterval), + core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval), }, }, func(t *testing.T, xEnv *testenv.Environment) { EnsureTopicExists(t, xEnv, topics...) diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index d3235643b4..9e1558db24 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -324,7 +324,7 @@ func TestNatsEvents(t *testing.T) { testenv.Run(t, &testenv.Config{ RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate, RouterOptions: []core.Option{ - core.WithMultipartHeartbeatInterval(heartbeatInterval), + core.WithSubscriptionHeartbeatInterval(heartbeatInterval), }, EnableNats: true, TLSConfig: &core.TlsConfig{ @@ -378,7 +378,7 @@ func TestNatsEvents(t *testing.T) { EnableNats: true, TLSConfig: nil, // Force Http/1 RouterOptions: []core.Option{ - core.WithMultipartHeartbeatInterval(heartbeatInterval), + core.WithSubscriptionHeartbeatInterval(heartbeatInterval), }, }, func(t *testing.T, xEnv *testenv.Environment) { diff --git a/router-tests/events/redis_events_test.go b/router-tests/events/redis_events_test.go index 1c287f7d6f..f6c9e54d13 100644 --- a/router-tests/events/redis_events_test.go +++ b/router-tests/events/redis_events_test.go @@ -478,7 +478,7 @@ func TestRedisEvents(t *testing.T) { t.Run("multipart", func(t *testing.T) { t.Parallel() - multipartHeartbeatInterval := time.Second * 5 + subscriptionHeartbeatInterval := time.Second * 5 t.Run("subscribe sync", func(t *testing.T) { t.Parallel() @@ -489,7 +489,7 @@ func TestRedisEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsRedisJSONTemplate, EnableRedis: true, RouterOptions: []core.Option{ - core.WithMultipartHeartbeatInterval(multipartHeartbeatInterval), + core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval), }, }, func(t *testing.T, xEnv *testenv.Environment) { subscribePayload := []byte(`{"query":"subscription { employeeUpdates { id details { forename surname } }}"}`) diff --git a/router-tests/go.mod b/router-tests/go.mod index 08ab9fbaae..bec8f63fb7 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -27,7 +27,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20250820135159-bf8852195d3f github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223 go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 diff --git a/router-tests/go.sum b/router-tests/go.sum index 4a69658e75..87ab146b7f 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -352,8 +352,8 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoTjFDDcgaECCCR2aTePqMu9QBmPbyhqIYOhV0= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220 h1:+imPYcv+XExZ+ofX5jCxtaA7upeys7uWA7RsTZiTTWE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220/go.mod h1:DnYY1alnsgzkanSwbFiFIdXKOuf8dHQWQ2P4BzTc6aI= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223 h1:PUYcDoqkgqDnZVpO0c+y80kR308OQBtFzRPPegr0bIk= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223/go.mod h1:DnYY1alnsgzkanSwbFiFIdXKOuf8dHQWQ2P4BzTc6aI= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router-tests/http_subscriptions_test.go b/router-tests/http_subscriptions_test.go new file mode 100644 index 0000000000..833334082b --- /dev/null +++ b/router-tests/http_subscriptions_test.go @@ -0,0 +1,201 @@ +package integration + +import ( + "bufio" + "bytes" + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" +) + +func readMultipartPrefix(reader *bufio.Reader) error { + blankHeader, _, err := reader.ReadLine() + if err != nil { + return err + } + + if len(blankHeader) != 0 { + return fmt.Errorf("expected blank header, got %q", blankHeader) + } + + graphQLHeader, _, err := reader.ReadLine() + if err != nil { + return err + } + + if string(graphQLHeader) != "--graphql" { + return fmt.Errorf("expected graphql header, got %q", graphQLHeader) + } + + contentTypeHeader, _, err := reader.ReadLine() + if err != nil { + return err + } + + if string(contentTypeHeader) != "Content-Type: application/json" { + return fmt.Errorf("expected content type header, got %q", contentTypeHeader) + } + + blankFooter, _, err := reader.ReadLine() + if err != nil { + return err + } + + if len(blankFooter) != 0 { + return fmt.Errorf("expected blank footer, got %q", blankFooter) + } + + return nil +} + +func TestHeartbeats(t *testing.T) { + subscriptionHeartbeatInterval := time.Millisecond * 300 + + t.Run("should work correctly for multipart", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + client := http.Client{ + Timeout: time.Second * 100, + } + + subscribePayload := []byte(`{"query":"subscription { countEmp(max: 5, intervalMilliseconds: 550) }"}`) + + req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload)) + resp, gErr := client.Do(req) + require.NoError(t, gErr) + require.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + reader := bufio.NewReader(resp.Body) + + messages := make(chan string, 1) + + go func() { + defer close(messages) + for { + err := readMultipartPrefix(reader) + if err != nil { + return + } + + line, _, err := reader.ReadLine() + if err != nil { + return + } + + fmt.Println(string(line)) + messages <- string(line) + } + }() + + for i := 0; i <= 5; i++ { + testenv.AwaitChannelWithT(t, 5*time.Second, messages, func(t *testing.T, msg string) { + assert.Equal(t, fmt.Sprintf(`{"payload":{"data":{"countEmp":%d}}}`, i), msg) + }) + + testenv.AwaitChannelWithT(t, 5*time.Second, messages, func(t *testing.T, msg string) { + assert.Equal(t, `{}`, msg) + }) + } + + // Channel should be closed after all heartbeats are received + testenv.AwaitChannelWithCloseWithT(t, 5*time.Second, messages, func(t *testing.T, _ string, ok bool) { + require.False(t, ok, "channel should be closed") + }) + }) + }) + + t.Run("should work correctly for sse", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + client := http.Client{ + Timeout: time.Second * 100, + } + + subscribePayload := []byte(`{"query":"subscription { countEmp(max: 5, intervalMilliseconds: 550) }"}`) + + req, err := http.NewRequest(http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(subscribePayload)) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Cache-Control", "no-cache") + + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + defer resp.Body.Close() + reader := bufio.NewReader(resp.Body) + + lines := make(chan string, 50) + + go func() { + defer close(lines) + for { + line, _, err := reader.ReadLine() + if err != nil { + return + } + lines <- string(line) + } + }() + + // Assert the expected SSE sequence + for i := 0; i <= 5; i++ { + // Expect "event: next" + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "event: next", line) + }) + + // Expect data line with count + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, fmt.Sprintf(`data: {"data":{"countEmp":%d}}`, i), line) + }) + + // Expect blank line + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "", line) + }) + + // Expect heartbeat + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, ":heartbeat", line) + }) + + // Expect blank line after heartbeat + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "", line) + }) + } + + // Expect completion event + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "event: complete", line) + }) + + // Expect empty data line event + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "data: ", line) + }) + + // Expect blank line after complete + testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) { + assert.Equal(t, "", line) + }) + }) + }) +} diff --git a/router-tests/testenv/utils.go b/router-tests/testenv/utils.go index 47bac51fdd..bd4f1842db 100644 --- a/router-tests/testenv/utils.go +++ b/router-tests/testenv/utils.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func AwaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(*testing.T, A), msgAndArgs ...interface{}) { +func AwaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(*testing.T, A), msgAndArgs ...any) { t.Helper() select { @@ -17,3 +17,14 @@ func AwaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, require.Fail(t, "unable to receive message before timeout", msgAndArgs...) } } + +func AwaitChannelWithCloseWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(t *testing.T, item A, ok bool), msgAndArgs ...any) { + t.Helper() + + select { + case args, ok := <-ch: + f(t, args, ok) + case <-time.After(timeout): + require.Fail(t, "unable to receive message before timeout", msgAndArgs...) + } +} diff --git a/router/core/executor.go b/router/core/executor.go index 560f40840e..2a8384bf99 100644 --- a/router/core/executor.go +++ b/router/core/executor.go @@ -83,7 +83,7 @@ func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts *Executor AllowedSubgraphErrorFields: opts.RouterEngineConfig.SubgraphErrorPropagation.AllowedFields, AllowAllErrorExtensionFields: opts.RouterEngineConfig.SubgraphErrorPropagation.AllowAllExtensionFields, MaxRecyclableParserSize: opts.RouterEngineConfig.Execution.ResolverMaxRecyclableParserSize, - MultipartSubHeartbeatInterval: opts.HeartbeatInterval, + SubscriptionHeartbeatInterval: opts.HeartbeatInterval, MaxSubscriptionFetchTimeout: opts.RouterEngineConfig.Execution.SubscriptionFetchTimeout, } diff --git a/router/core/flushwriter.go b/router/core/flushwriter.go index 5e507babf9..90456ffc89 100644 --- a/router/core/flushwriter.go +++ b/router/core/flushwriter.go @@ -53,7 +53,7 @@ func (f *HttpFlushWriter) Complete() { return } if f.sse { - _, _ = f.writer.Write([]byte("event: complete")) + _, _ = f.writer.Write([]byte("event: complete\ndata: \n\n")) } else if f.multipart { // Write the final boundary in the multipart response if f.apolloSubscriptionMultipartPrintBoundary { @@ -77,6 +77,33 @@ func (f *HttpFlushWriter) Write(p []byte) (n int, err error) { return f.buf.Write(p) } +func (f *HttpFlushWriter) Heartbeat() error { + if err := f.ctx.Err(); err != nil { + return err + } + + var heartbeat []byte + if f.sse { + heartbeat = []byte(":heartbeat\n\n") + + if _, err := f.writer.Write(heartbeat); err != nil { + return err + } + + f.flusher.Flush() + } else if f.multipart { + if _, err := f.Write([]byte("{}")); err != nil { + return err + } + + if err := f.Flush(); err != nil { + return err + } + } + + return nil +} + func (f *HttpFlushWriter) Close(_ resolve.SubscriptionCloseKind) { if f.ctx.Err() != nil { return @@ -159,7 +186,7 @@ func GetSubscriptionResponseWriter(ctx *resolve.Context, r *http.Request, w http flushWriter.ctx, flushWriter.cancel = context.WithCancel(ctx.Context()) ctx = ctx.WithContext(flushWriter.ctx) - if wgParams.UseMultipart { + if wgParams.UseMultipart || wgParams.UseSse { ctx.ExecutionOptions.SendHeartbeat = true } diff --git a/router/core/graph_server.go b/router/core/graph_server.go index da70696ca8..380b8f0573 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1203,7 +1203,7 @@ func (s *graphServer) buildGraphMux( Reporter: s.engineStats, ApolloCompatibilityFlags: s.apolloCompatibilityFlags, ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, - HeartbeatInterval: s.multipartHeartbeatInterval, + HeartbeatInterval: s.subscriptionHeartbeatInterval, PluginsEnabled: s.plugins.Enabled, InstanceData: s.instanceData, }, diff --git a/router/core/router.go b/router/core/router.go index 7a5c3f6102..e5388aeeb7 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -1541,10 +1541,10 @@ func WithCors(corsOpts *cors.Config) Option { } } -// WithMultipartHeartbeatInterval sets the interval for the engine to send heartbeats for multipart subscriptions. -func WithMultipartHeartbeatInterval(interval time.Duration) Option { +// WithSubscriptionHeartbeatInterval sets the interval for the engine to send heartbeats for multipart subscriptions. +func WithSubscriptionHeartbeatInterval(interval time.Duration) Option { return func(r *Router) { - r.multipartHeartbeatInterval = interval + r.subscriptionHeartbeatInterval = interval } } diff --git a/router/core/router_config.go b/router/core/router_config.go index 0b70b1a72c..2a2a78924d 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -106,18 +106,18 @@ type Config struct { // should be removed once the users have migrated to the new overrides config overrideRoutingURLConfiguration config.OverrideRoutingURLConfiguration // the new overrides config - overrides config.OverridesConfiguration - authorization *config.AuthorizationConfiguration - rateLimit *config.RateLimitConfiguration - webSocketConfiguration *config.WebSocketConfiguration - subgraphErrorPropagation config.SubgraphErrorPropagationConfiguration - clientHeader config.ClientHeader - cacheWarmup *config.CacheWarmupConfiguration - multipartHeartbeatInterval time.Duration - hostName string - mcp config.MCPConfiguration - plugins config.PluginsConfiguration - tracingAttributes []config.CustomAttribute + overrides config.OverridesConfiguration + authorization *config.AuthorizationConfiguration + rateLimit *config.RateLimitConfiguration + webSocketConfiguration *config.WebSocketConfiguration + subgraphErrorPropagation config.SubgraphErrorPropagationConfiguration + clientHeader config.ClientHeader + cacheWarmup *config.CacheWarmupConfiguration + subscriptionHeartbeatInterval time.Duration + hostName string + mcp config.MCPConfiguration + plugins config.PluginsConfiguration + tracingAttributes []config.CustomAttribute } // Usage returns an anonymized version of the config for usage tracking diff --git a/router/core/websocket.go b/router/core/websocket.go index d1942e97c9..c0dcc78518 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -627,6 +627,11 @@ func (rw *websocketResponseWriter) Complete() { } } +// Heartbeat is a no-op function for WebSocket subscriptions. +func (rw *websocketResponseWriter) Heartbeat() error { + return nil +} + func (rw *websocketResponseWriter) Close(kind resolve.SubscriptionCloseKind) { err := rw.protocol.Close(kind.WSCode, kind.Reason) if err != nil { diff --git a/router/go.mod b/router/go.mod index 402ab2d326..4954a40fe1 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223 // Do not upgrade, it renames attributes we rely on go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 diff --git a/router/go.sum b/router/go.sum index 0f23a8fce4..a72296f7de 100644 --- a/router/go.sum +++ b/router/go.sum @@ -317,8 +317,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220 h1:+imPYcv+XExZ+ofX5jCxtaA7upeys7uWA7RsTZiTTWE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220/go.mod h1:DnYY1alnsgzkanSwbFiFIdXKOuf8dHQWQ2P4BzTc6aI= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223 h1:PUYcDoqkgqDnZVpO0c+y80kR308OQBtFzRPPegr0bIk= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223/go.mod h1:DnYY1alnsgzkanSwbFiFIdXKOuf8dHQWQ2P4BzTc6aI= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=