Skip to content
4 changes: 2 additions & 2 deletions execution/subscription/legacy_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func TestHandler_Handle(t *testing.T) {
time.Sleep(10 * time.Millisecond)
cancelFunc()

go sendChatMutation(t, chatServer.URL)
sendChatMutation(t, chatServer.URL)

require.Eventually(t, func() bool {
return client.hasMoreMessagesThan(0)
Expand Down Expand Up @@ -481,7 +481,7 @@ func TestHandler_Handle(t *testing.T) {
time.Sleep(10 * time.Millisecond)
cancelFunc()

go sendChatMutation(t, chatServer.URL)
sendChatMutation(t, chatServer.URL)

require.Eventually(t, func() bool {
return client.hasMoreMessagesThan(0)
Expand Down
21 changes: 12 additions & 9 deletions v2/pkg/engine/resolve/inbound_request_singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package resolve
import (
"encoding/binary"
"sync"
"sync/atomic"

"github.com/wundergraph/graphql-go-tools/v2/pkg/pool"
)
Expand Down Expand Up @@ -46,8 +47,15 @@ type InflightRequest struct {
Err error
ID uint64

HasFollowers bool
Mu sync.Mutex
followerCount atomic.Int32
}

// AddFollower atomically increments the follower counter by one.
// GetOrCreate calls it when a request joins an already in-flight (shared)
// request, before that request starts waiting on Done.
func (r *InflightRequest) AddFollower() {
r.followerCount.Add(1)
}

// HasFollowers reports whether at least one follower has registered on this
// in-flight request via AddFollower. The counter is read atomically, so the
// method may be called concurrently with AddFollower.
func (r *InflightRequest) HasFollowers() bool {
	registered := r.followerCount.Load()
	return registered > 0
}

// GetOrCreate creates a new InflightRequest or returns an existing (shared) one
Expand Down Expand Up @@ -90,9 +98,7 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL
inflight, shared := shard.m.LoadOrStore(key, request)
if shared {
request = inflight.(*InflightRequest)
request.Mu.Lock()
request.HasFollowers = true
request.Mu.Unlock()
request.AddFollower()
select {
case <-request.Done:
if request.Err != nil {
Expand All @@ -113,10 +119,7 @@ func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte)
}
shard := r.shardFor(req.ID)
shard.m.Delete(req.ID)
req.Mu.Lock()
hasFollowers := req.HasFollowers
req.Mu.Unlock()
if hasFollowers {
if req.HasFollowers() {
// optimization to only copy when we actually have to
req.Data = make([]byte, len(data))
copy(req.Data, data)
Expand Down
25 changes: 10 additions & 15 deletions v2/pkg/engine/resolve/inbound_request_singleflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"sync"
"testing"
"time"

"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
)
Expand Down Expand Up @@ -75,8 +76,7 @@ func TestInboundSingleFlight_FollowerReceivesLeaderError(t *testing.T) {
}

// The follower calls GetOrCreate which blocks on inflight.Done.
// We wait for HasFollowers to be set before calling FinishErr.
followerReady := make(chan struct{})
// We wait for followerCount to become non-zero, which confirms the follower has entered GetOrCreate, before calling FinishErr.
var wg sync.WaitGroup
wg.Add(1)

Expand All @@ -85,25 +85,20 @@ func TestInboundSingleFlight_FollowerReceivesLeaderError(t *testing.T) {
followerCtx := NewContext(context.Background())
followerCtx.Request.ID = 2

// Signal that we're about to enter GetOrCreate. HasFollowers will be
// set inside GetOrCreate before the select blocks, so closing
// followerReady here is slightly early, but we poll HasFollowers below.
close(followerReady)

_, followerErr := sf.GetOrCreate(followerCtx, response)
if followerErr == nil {
t.Error("expected error from follower after leader FinishErr")
}
}()

<-followerReady
// Spin until the follower has actually registered (set HasFollowers)
for {
inflight.Mu.Lock()
ready := inflight.HasFollowers
inflight.Mu.Unlock()
if ready {
break
// Poll until the follower has actually registered inside GetOrCreate.
deadline := time.After(3 * time.Second)
for !inflight.HasFollowers() {
select {
case <-deadline:
t.Fatal("timeout waiting for follower to enter singleflight")
default:
time.Sleep(10 * time.Millisecond)
}
}

Expand Down
82 changes: 54 additions & 28 deletions v2/pkg/engine/resolve/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,40 @@ func (w *blockingWriter) String() string {
return w.buf.String()
}

// findAnyInflight walks the singleflight shards in index order and returns
// the first in-flight request stored in any of them, or nil when none of the
// shards holds an entry. Tests use it to reach an in-flight request so they
// can poll its followerCount.
func findAnyInflight(r *Resolver) *InflightRequest {
	var match *InflightRequest
	// grabFirst records the first visited entry and returns false so that
	// Range stops iterating immediately.
	grabFirst := func(_, entry any) bool {
		match = entry.(*InflightRequest)
		return false
	}
	// Shards are addressed by index so that no shard value is copied.
	for idx := range r.inboundRequestSingleFlight.shards {
		r.inboundRequestSingleFlight.shards[idx].m.Range(grabFirst)
		if match != nil {
			return match
		}
	}
	return nil
}

// waitForFollowerCount blocks the calling test until the in-flight request
// returned by findAnyInflight has at least count registered followers.
// It re-checks every 10ms and fails the test with t.Fatal once 3 seconds
// have elapsed without the condition being met.
func waitForFollowerCount(t *testing.T, r *Resolver, count int32) {
	t.Helper()

	// reached reports whether an in-flight request exists and already has
	// the required number of followers.
	reached := func() bool {
		req := findAnyInflight(r)
		if req == nil {
			return false
		}
		return req.followerCount.Load() >= count
	}

	timeout := time.After(3 * time.Second)
	for !reached() {
		select {
		case <-timeout:
			t.Fatal("timeout waiting for followers to enter singleflight")
		default:
			// Not timed out yet: back off briefly before polling again.
			time.Sleep(10 * time.Millisecond)
		}
	}
}

// TestErrorWriter is an empty struct type declared for use by the tests in
// this file.
// NOTE(review): its methods are outside this view — presumably a writer that
// is used to exercise error paths; confirm against the method set.
type TestErrorWriter struct {
}

Expand Down Expand Up @@ -4694,30 +4728,20 @@ func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication(t *testing.T)
t.Fatalf("timeout waiting for leader data source load")
}

startFollowers := make(chan struct{})
followersEntered := make(chan struct{}, requestCount-1)

for i := 1; i < requestCount; i++ {
go func(i int) {
defer wg.Done()
ctx := ctxTemplate
<-startFollowers
followersEntered <- struct{}{}
buf := &bytes.Buffer{}
info, err := r.ArenaResolveGraphQLResponse(&ctx, response, buf)
results[i] = result{info: info, output: buf.String(), err: err}
}(i)
}

close(startFollowers)

for i := 1; i < requestCount; i++ {
select {
case <-followersEntered:
case <-time.After(time.Second):
t.Fatalf("timeout waiting for follower %d to start", i)
}
}
// Wait until all followers have entered the singleflight (called LoadOrStore)
// before releasing the data source. This guarantees they join the leader's
// inflight request rather than creating their own.
waitForFollowerCount(t, r, int32(requestCount-1))

ds.Release()

Expand Down Expand Up @@ -4823,9 +4847,6 @@ func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication_SharedData(t
t.Fatalf("timeout waiting for leader data source load")
}

startFollowers := make(chan struct{})
followersEntered := make(chan struct{}, requestCount-1)

for i := 1; i < requestCount; i++ {
go func(i int) {
defer wg.Done()
Expand All @@ -4838,23 +4859,16 @@ func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication_SharedData(t
followerData.Store(i, data)
},
)
<-startFollowers
followersEntered <- struct{}{}
buf := &bytes.Buffer{}
info, err := r.ArenaResolveGraphQLResponse(&ctx, response, buf)
results[i] = result{info: info, output: buf.String(), err: err}
}(i)
}

close(startFollowers)

for i := 1; i < requestCount; i++ {
select {
case <-followersEntered:
case <-time.After(time.Second):
t.Fatalf("timeout waiting for follower %d to start", i)
}
}
// Wait until all followers have entered the singleflight (called LoadOrStore)
// before releasing the data source. This guarantees they join the leader's
// inflight request rather than creating their own.
waitForFollowerCount(t, r, int32(requestCount-1))

ds.Release()

Expand Down Expand Up @@ -6426,10 +6440,21 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {
c, cancel := context.WithCancel(context.Background())
defer cancel()

// sub2Ready gates the data source goroutine so that it doesn't start
// emitting before sub2 has been registered on the trigger. Without this,
// the emitting goroutine's first triggerUpdate can race sub2's
// addSubscription on the unbuffered events channel, causing sub2 to
// miss counter=0.
sub2Ready := make(chan struct{})
fakeStream := createFakeStream(func(counter int) (message string, done bool) {
return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 100
}, 1*time.Millisecond, func(input []byte) {
assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input))
// Block the data source goroutine until sub2 is registered.
// onStart runs inside the goroutine that calls Start(), not the
// event loop, so blocking here is safe — the event loop remains
// free to process sub2's addSubscription event.
<-sub2Ready
}, func(ctx StartupHookContext, input []byte) (err error) {
return nil
})
Expand All @@ -6451,6 +6476,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) {

err2 := resolver1.AsyncResolveGraphQLSubscription(ctx2, plan1, recorder2, id2)
assert.NoError(t, err2)
close(sub2Ready)

// complete is called only on the last recorder
recorder1.AwaitComplete(t, defaultTimeout)
Expand Down