Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions router-tests/events/kafka_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func TestKafkaEvents(t *testing.T) {
t.Run("multipart", func(t *testing.T) {
t.Parallel()

multipartHeartbeatInterval := time.Second * 5
subscriptionHeartbeatInterval := time.Second * 5

t.Run("subscribe sync", func(t *testing.T) {
t.Parallel()
Expand All @@ -428,7 +428,7 @@ func TestKafkaEvents(t *testing.T) {
RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate,
EnableKafka: true,
RouterOptions: []core.Option{
core.WithMultipartHeartbeatInterval(multipartHeartbeatInterval),
core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
EnsureTopicExists(t, xEnv, topics...)
Expand Down
4 changes: 2 additions & 2 deletions router-tests/events/nats_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func TestNatsEvents(t *testing.T) {
testenv.Run(t, &testenv.Config{
RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate,
RouterOptions: []core.Option{
core.WithMultipartHeartbeatInterval(heartbeatInterval),
core.WithSubscriptionHeartbeatInterval(heartbeatInterval),
},
EnableNats: true,
TLSConfig: &core.TlsConfig{
Expand Down Expand Up @@ -378,7 +378,7 @@ func TestNatsEvents(t *testing.T) {
EnableNats: true,
TLSConfig: nil, // Force Http/1
RouterOptions: []core.Option{
core.WithMultipartHeartbeatInterval(heartbeatInterval),
core.WithSubscriptionHeartbeatInterval(heartbeatInterval),
},
}, func(t *testing.T, xEnv *testenv.Environment) {

Expand Down
4 changes: 2 additions & 2 deletions router-tests/events/redis_events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ func TestRedisEvents(t *testing.T) {
t.Run("multipart", func(t *testing.T) {
t.Parallel()

multipartHeartbeatInterval := time.Second * 5
subscriptionHeartbeatInterval := time.Second * 5

t.Run("subscribe sync", func(t *testing.T) {
t.Parallel()
Expand All @@ -489,7 +489,7 @@ func TestRedisEvents(t *testing.T) {
RouterConfigJSONTemplate: testenv.ConfigWithEdfsRedisJSONTemplate,
EnableRedis: true,
RouterOptions: []core.Option{
core.WithMultipartHeartbeatInterval(multipartHeartbeatInterval),
core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
subscribePayload := []byte(`{"query":"subscription { employeeUpdates { id details { forename surname } }}"}`)
Expand Down
2 changes: 1 addition & 1 deletion router-tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ require (
github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e
github.com/wundergraph/cosmo/router v0.0.0-20250820135159-bf8852195d3f
github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223
go.opentelemetry.io/otel v1.36.0
go.opentelemetry.io/otel/sdk v1.36.0
go.opentelemetry.io/otel/sdk/metric v1.36.0
Expand Down
4 changes: 2 additions & 2 deletions router-tests/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB
github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE=
github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoTjFDDcgaECCCR2aTePqMu9QBmPbyhqIYOhV0=
github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0=
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220 h1:+imPYcv+XExZ+ofX5jCxtaA7upeys7uWA7RsTZiTTWE=
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.220/go.mod h1:DnYY1alnsgzkanSwbFiFIdXKOuf8dHQWQ2P4BzTc6aI=
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223 h1:PUYcDoqkgqDnZVpO0c+y80kR308OQBtFzRPPegr0bIk=
github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.223/go.mod h1:DnYY1alnsgzkanSwbFiFIdXKOuf8dHQWQ2P4BzTc6aI=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
Expand Down
201 changes: 201 additions & 0 deletions router-tests/http_subscriptions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package integration

import (
"bufio"
"bytes"
"fmt"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
)

func readMultipartPrefix(reader *bufio.Reader) error {
blankHeader, _, err := reader.ReadLine()
if err != nil {
return err
}

if len(blankHeader) != 0 {
return fmt.Errorf("expected blank header, got %q", blankHeader)
}

graphQLHeader, _, err := reader.ReadLine()
if err != nil {
return err
}

if string(graphQLHeader) != "--graphql" {
return fmt.Errorf("expected graphql header, got %q", graphQLHeader)
}

contentTypeHeader, _, err := reader.ReadLine()
if err != nil {
return err
}

if string(contentTypeHeader) != "Content-Type: application/json" {
return fmt.Errorf("expected content type header, got %q", contentTypeHeader)
}

blankFooter, _, err := reader.ReadLine()
if err != nil {
return err
}

if len(blankFooter) != 0 {
return fmt.Errorf("expected blank footer, got %q", blankFooter)
}

return nil
}
Comment thread
endigma marked this conversation as resolved.

func TestHeartbeats(t *testing.T) {
subscriptionHeartbeatInterval := time.Millisecond * 300

t.Run("should work correctly for multipart", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
client := http.Client{
Timeout: time.Second * 100,
}

subscribePayload := []byte(`{"query":"subscription { countEmp(max: 5, intervalMilliseconds: 550) }"}`)

req := xEnv.MakeGraphQLMultipartRequest(http.MethodPost, bytes.NewReader(subscribePayload))
resp, gErr := client.Do(req)
require.NoError(t, gErr)
require.Equal(t, http.StatusOK, resp.StatusCode)

defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)

messages := make(chan string, 1)

go func() {
defer close(messages)
for {
err := readMultipartPrefix(reader)
if err != nil {
return
}

line, _, err := reader.ReadLine()
if err != nil {
return
}

fmt.Println(string(line))
messages <- string(line)
}
}()

for i := 0; i <= 5; i++ {
testenv.AwaitChannelWithT(t, 5*time.Second, messages, func(t *testing.T, msg string) {
assert.Equal(t, fmt.Sprintf(`{"payload":{"data":{"countEmp":%d}}}`, i), msg)
})

testenv.AwaitChannelWithT(t, 5*time.Second, messages, func(t *testing.T, msg string) {
assert.Equal(t, `{}`, msg)
})
}
Comment thread
endigma marked this conversation as resolved.

// Channel should be closed after all heartbeats are received
testenv.AwaitChannelWithCloseWithT(t, 5*time.Second, messages, func(t *testing.T, _ string, ok bool) {
require.False(t, ok, "channel should be closed")
})
})
})

t.Run("should work correctly for sse", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
client := http.Client{
Timeout: time.Second * 100,
}

subscribePayload := []byte(`{"query":"subscription { countEmp(max: 5, intervalMilliseconds: 550) }"}`)

req, err := http.NewRequest(http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(subscribePayload))
require.NoError(t, err)

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Connection", "keep-alive")
req.Header.Set("Cache-Control", "no-cache")

resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)

lines := make(chan string, 50)

go func() {
defer close(lines)
for {
line, _, err := reader.ReadLine()
if err != nil {
return
}
lines <- string(line)
}
}()

// Assert the expected SSE sequence
for i := 0; i <= 5; i++ {
// Expect "event: next"
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, "event: next", line)
})

// Expect data line with count
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, fmt.Sprintf(`data: {"data":{"countEmp":%d}}`, i), line)
})

// Expect blank line
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, "", line)
})

// Expect heartbeat
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, ":heartbeat", line)
})

// Expect blank line after heartbeat
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, "", line)
})
}

// Expect completion event
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, "event: complete", line)
})

// Expect empty data line event
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, "data: ", line)
})

// Expect blank line after complete
testenv.AwaitChannelWithT(t, 5*time.Second, lines, func(t *testing.T, line string) {
assert.Equal(t, "", line)
})
})
})
}
13 changes: 12 additions & 1 deletion router-tests/testenv/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/require"
)

func AwaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(*testing.T, A), msgAndArgs ...interface{}) {
func AwaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(*testing.T, A), msgAndArgs ...any) {
t.Helper()

select {
Expand All @@ -17,3 +17,14 @@ func AwaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A,
require.Fail(t, "unable to receive message before timeout", msgAndArgs...)
}
}

func AwaitChannelWithCloseWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(t *testing.T, item A, ok bool), msgAndArgs ...any) {
t.Helper()

select {
case args, ok := <-ch:
f(t, args, ok)
case <-time.After(timeout):
require.Fail(t, "unable to receive message before timeout", msgAndArgs...)
}
}
2 changes: 1 addition & 1 deletion router/core/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts *Executor
AllowedSubgraphErrorFields: opts.RouterEngineConfig.SubgraphErrorPropagation.AllowedFields,
AllowAllErrorExtensionFields: opts.RouterEngineConfig.SubgraphErrorPropagation.AllowAllExtensionFields,
MaxRecyclableParserSize: opts.RouterEngineConfig.Execution.ResolverMaxRecyclableParserSize,
MultipartSubHeartbeatInterval: opts.HeartbeatInterval,
SubscriptionHeartbeatInterval: opts.HeartbeatInterval,
MaxSubscriptionFetchTimeout: opts.RouterEngineConfig.Execution.SubscriptionFetchTimeout,
}

Expand Down
31 changes: 29 additions & 2 deletions router/core/flushwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (f *HttpFlushWriter) Complete() {
return
}
if f.sse {
_, _ = f.writer.Write([]byte("event: complete"))
_, _ = f.writer.Write([]byte("event: complete\ndata: \n\n"))
} else if f.multipart {
// Write the final boundary in the multipart response
if f.apolloSubscriptionMultipartPrintBoundary {
Expand All @@ -77,6 +77,33 @@ func (f *HttpFlushWriter) Write(p []byte) (n int, err error) {
return f.buf.Write(p)
}

func (f *HttpFlushWriter) Heartbeat() error {
Comment thread
StarpTech marked this conversation as resolved.
if err := f.ctx.Err(); err != nil {
return err
}

var heartbeat []byte
if f.sse {
heartbeat = []byte(":heartbeat\n\n")

if _, err := f.writer.Write(heartbeat); err != nil {
return err
}

f.flusher.Flush()
} else if f.multipart {
if _, err := f.Write([]byte("{}")); err != nil {
return err
}

if err := f.Flush(); err != nil {
return err
}
}

return nil
}

func (f *HttpFlushWriter) Close(_ resolve.SubscriptionCloseKind) {
if f.ctx.Err() != nil {
return
Expand Down Expand Up @@ -159,7 +186,7 @@ func GetSubscriptionResponseWriter(ctx *resolve.Context, r *http.Request, w http
flushWriter.ctx, flushWriter.cancel = context.WithCancel(ctx.Context())
ctx = ctx.WithContext(flushWriter.ctx)

if wgParams.UseMultipart {
if wgParams.UseMultipart || wgParams.UseSse {
ctx.ExecutionOptions.SendHeartbeat = true
}

Expand Down
2 changes: 1 addition & 1 deletion router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ func (s *graphServer) buildGraphMux(
Reporter: s.engineStats,
ApolloCompatibilityFlags: s.apolloCompatibilityFlags,
ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags,
HeartbeatInterval: s.multipartHeartbeatInterval,
HeartbeatInterval: s.subscriptionHeartbeatInterval,
PluginsEnabled: s.plugins.Enabled,
InstanceData: s.instanceData,
},
Expand Down
6 changes: 3 additions & 3 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -1541,10 +1541,10 @@ func WithCors(corsOpts *cors.Config) Option {
}
}

// WithMultipartHeartbeatInterval sets the interval for the engine to send heartbeats for multipart subscriptions.
func WithMultipartHeartbeatInterval(interval time.Duration) Option {
// WithSubscriptionHeartbeatInterval sets the interval for the engine to send heartbeats for multipart subscriptions.
func WithSubscriptionHeartbeatInterval(interval time.Duration) Option {
return func(r *Router) {
r.multipartHeartbeatInterval = interval
r.subscriptionHeartbeatInterval = interval
}
}

Expand Down
Loading
Loading