Skip to content
100 changes: 68 additions & 32 deletions v2/pkg/engine/resolve/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5526,26 +5526,38 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
defer cancel()

executed := atomic.Bool{}
subsStarted := sync.WaitGroup{}
subsStarted.Add(2)

id2 := SubscriptionIdentifier{
ConnectionID: 1,
SubscriptionID: 2,
}

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0
}, 1*time.Millisecond, func(input []byte) {
streamCanStart := make(chan struct{})
startupHookWaitGroup := sync.WaitGroup{}
startupHookWaitGroup.Add(2)

Comment thread
alepane21 marked this conversation as resolved.
// this message must come as last, on both recorders.
messageFn := func(counter int) (message string, done bool) {
<-streamCanStart
return `{"data":{"counter":0}}`, true
}

onStartFn := func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
}, func(ctx StartupHookContext, input []byte) (err error) {
}

// this message must come first on the first recorder to be added to the trigger.
subscriptionOnStartFn := func(ctx StartupHookContext, input []byte) (err error) {
defer startupHookWaitGroup.Done()
if executed.Load() {
return
}
executed.Store(true)
ctx.Updater([]byte(`{"data":{"counter":1000}}`))
return nil
})
}

fakeStream := createFakeStream(messageFn, time.Millisecond, onStartFn, subscriptionOnStartFn)
fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) {
return nil
}
Expand Down Expand Up @@ -5575,20 +5587,35 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id)
assert.NoError(t, err)
subsStarted.Done()

err2 := resolver.AsyncResolveGraphQLSubscription(ctx2, plan, recorder2, id2)
assert.NoError(t, err2)
subsStarted.Done()

done := make(chan struct{})
go func() {
startupHookWaitGroup.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(defaultTimeout):
t.Fatal("timed out waiting for subscription startup hooks")
}

// Wait for both subscriptions startup hooks to be executed
startupHookWaitGroup.Wait()

// Signal the stream to send its message now that both subscriptions are ready
close(streamCanStart)

recorder.AwaitComplete(t, defaultTimeout)
recorder2.AwaitComplete(t, defaultTimeout)

recorders := []*SubscriptionRecorder{recorder, recorder2}

recorderWith1Message := false
recorderWith2Messages := false

recorders := []*SubscriptionRecorder{recorder, recorder2}

for _, r := range recorders {
if len(r.Messages()) == 2 {
recorderWith2Messages = true
Expand All @@ -5601,30 +5628,38 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
}
}

assert.True(t, recorderWith1Message)
assert.True(t, recorderWith2Messages)
assert.True(t, recorderWith1Message, "recorder 1: %v, recorder 2: %v", recorder.messages, recorder2.messages)
assert.True(t, recorderWith2Messages, "recorder 1: %v, recorder 2: %v", recorder.messages, recorder2.messages)
})

t.Run("SubscriptionOnStart ctx updater on multiple subscriptions with same trigger works", func(t *testing.T) {
c, cancel := context.WithCancel(context.Background())
defer cancel()

subsStarted := sync.WaitGroup{}
subsStarted.Add(2)

id2 := SubscriptionIdentifier{
ConnectionID: 1,
SubscriptionID: 2,
}

fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 0
}, 1*time.Millisecond, func(input []byte) {
streamCanStart := make(chan struct{})

// Message function that waits for signal before sending the final message
messageFn := func(counter int) (message string, done bool) {
<-streamCanStart
return `{"data":{"counter":0}}`, true
}

onStartFn := func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
}, func(ctx StartupHookContext, input []byte) (err error) {
}

// this handler pushes the first message to both subscribers via subscription-on-start hook.
subscriptionOnStartFn := func(ctx StartupHookContext, input []byte) (err error) {
ctx.Updater([]byte(`{"data":{"counter":1000}}`))
return nil
})
}

fakeStream := createFakeStream(messageFn, 1*time.Millisecond, onStartFn, subscriptionOnStartFn)
fakeStream.uniqueRequestFn = func(ctx *Context, input []byte, xxh *xxhash.Digest) (err error) {
_, err = xxh.WriteString("unique")
return
Expand Down Expand Up @@ -5655,24 +5690,25 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id)
assert.NoError(t, err)
subsStarted.Done()

err2 := resolver.AsyncResolveGraphQLSubscription(ctx2, plan, recorder2, id2)
assert.NoError(t, err2)
subsStarted.Done()

// Wait for subscriptions to receive their initial message from the subscription-on-start hook.
// Then we know the subscriptions are fully registered on the trigger and then we can send
// the next message (by closing the streamCanStart channel).
recorder.AwaitAnyMessageCount(t, defaultTimeout)
recorder2.AwaitAnyMessageCount(t, defaultTimeout)
close(streamCanStart)

recorder.AwaitComplete(t, defaultTimeout)
recorder2.AwaitComplete(t, defaultTimeout)

recorders := []*SubscriptionRecorder{recorder, recorder2}

for _, r := range recorders {
if len(r.Messages()) == 2 {
assert.Equal(t, `{"data":{"counter":1000}}`, r.Messages()[0])
assert.Equal(t, `{"data":{"counter":0}}`, r.Messages()[1])
} else {
assert.Fail(t, "should not be here")
}
// Both recorders should have received both messages in the correct order.
for _, r := range []*SubscriptionRecorder{recorder, recorder2} {
assert.Len(t, r.Messages(), 2, "recorder messages: %v", r.messages)
assert.Equal(t, `{"data":{"counter":1000}}`, r.Messages()[0])
assert.Equal(t, `{"data":{"counter":0}}`, r.Messages()[1])
}
})

Expand Down