diff --git a/execution/subscription/legacy_handler_test.go b/execution/subscription/legacy_handler_test.go index 8ccb631711..1ca4933258 100644 --- a/execution/subscription/legacy_handler_test.go +++ b/execution/subscription/legacy_handler_test.go @@ -417,7 +417,7 @@ func TestHandler_Handle(t *testing.T) { time.Sleep(10 * time.Millisecond) cancelFunc() - go sendChatMutation(t, chatServer.URL) + sendChatMutation(t, chatServer.URL) require.Eventually(t, func() bool { return client.hasMoreMessagesThan(0) @@ -481,7 +481,7 @@ func TestHandler_Handle(t *testing.T) { time.Sleep(10 * time.Millisecond) cancelFunc() - go sendChatMutation(t, chatServer.URL) + sendChatMutation(t, chatServer.URL) require.Eventually(t, func() bool { return client.hasMoreMessagesThan(0) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index a67d8deccc..5affdb4d36 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -3,6 +3,7 @@ package resolve import ( "encoding/binary" "sync" + "sync/atomic" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) @@ -46,8 +47,15 @@ type InflightRequest struct { Err error ID uint64 - HasFollowers bool - Mu sync.Mutex + followerCount atomic.Int32 +} + +func (r *InflightRequest) AddFollower() { + r.followerCount.Add(1) +} + +func (r *InflightRequest) HasFollowers() bool { + return r.followerCount.Load() > 0 } // GetOrCreate creates a new InflightRequest or returns an existing (shared) one @@ -90,9 +98,7 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL inflight, shared := shard.m.LoadOrStore(key, request) if shared { request = inflight.(*InflightRequest) - request.Mu.Lock() - request.HasFollowers = true - request.Mu.Unlock() + request.AddFollower() select { case <-request.Done: if request.Err != nil { @@ -113,10 +119,7 @@ func (r *InboundRequestSingleFlight) FinishOk(req *InflightRequest, data []byte) } shard := r.shardFor(req.ID) shard.m.Delete(req.ID) - req.Mu.Lock() - hasFollowers := req.HasFollowers - req.Mu.Unlock() - if hasFollowers { + if req.HasFollowers() { // optimization to only copy when we actually have to req.Data = make([]byte, len(data)) copy(req.Data, data) diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight_test.go b/v2/pkg/engine/resolve/inbound_request_singleflight_test.go index a82d372d4f..8198b8723d 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight_test.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight_test.go @@ -4,6 +4,7 @@ import ( "context" "sync" "testing" + "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -75,8 +76,7 @@ func TestInboundSingleFlight_FollowerReceivesLeaderError(t *testing.T) { } // The follower calls GetOrCreate which blocks on inflight.Done. - // We wait for HasFollowers to be set before calling FinishErr. - followerReady := make(chan struct{}) + // We wait for followerCount to confirm it has entered before calling FinishErr. var wg sync.WaitGroup wg.Add(1) @@ -85,25 +85,20 @@ func TestInboundSingleFlight_FollowerReceivesLeaderError(t *testing.T) { followerCtx := NewContext(context.Background()) followerCtx.Request.ID = 2 - // Signal that we're about to enter GetOrCreate. HasFollowers will be - // set inside GetOrCreate before the select blocks, so closing - // followerReady here is slightly early, but we poll HasFollowers below. - close(followerReady) - _, followerErr := sf.GetOrCreate(followerCtx, response) if followerErr == nil { t.Error("expected error from follower after leader FinishErr") } }() - <-followerReady - // Spin until the follower has actually registered (set HasFollowers) - for { - inflight.Mu.Lock() - ready := inflight.HasFollowers - inflight.Mu.Unlock() - if ready { - break + // Poll until the follower has actually registered inside GetOrCreate. + deadline := time.After(3 * time.Second) + for !inflight.HasFollowers() { + select { + case <-deadline: + t.Fatal("timeout waiting for follower to enter singleflight") + default: + time.Sleep(10 * time.Millisecond) } } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 287ab72a63..56619c9035 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -149,6 +149,40 @@ func (w *blockingWriter) String() string { return w.buf.String() } +// findAnyInflight iterates through all singleflight shards and returns +// the first inflight request found. Used in tests to poll followerCount. +func findAnyInflight(r *Resolver) *InflightRequest { + for i := range r.inboundRequestSingleFlight.shards { + var found *InflightRequest + r.inboundRequestSingleFlight.shards[i].m.Range(func(_, value any) bool { + found = value.(*InflightRequest) + return false + }) + if found != nil { + return found + } + } + return nil +} + +// waitForFollowerCount polls until the inflight request has at least count followers registered. +func waitForFollowerCount(t *testing.T, r *Resolver, count int32) { + t.Helper() + deadline := time.After(3 * time.Second) + for { + inflight := findAnyInflight(r) + if inflight != nil && inflight.followerCount.Load() >= count { + return + } + select { + case <-deadline: + t.Fatal("timeout waiting for followers to enter singleflight") + default: + time.Sleep(10 * time.Millisecond) + } + } +} + type TestErrorWriter struct { } @@ -4694,30 +4728,20 @@ func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication(t *testing.T) t.Fatalf("timeout waiting for leader data source load") } - startFollowers := make(chan struct{}) - followersEntered := make(chan struct{}, requestCount-1) - for i := 1; i < requestCount; i++ { go func(i int) { defer wg.Done() ctx := ctxTemplate - <-startFollowers - followersEntered <- struct{}{} buf := &bytes.Buffer{} info, err := r.ArenaResolveGraphQLResponse(&ctx, response, buf) results[i] = result{info: info, output: buf.String(), err: err} }(i) } - close(startFollowers) - - for i := 1; i < requestCount; i++ { - select { - case <-followersEntered: - case <-time.After(time.Second): - t.Fatalf("timeout waiting for follower %d to start", i) - } - } + // Wait until all followers have entered the singleflight (called LoadOrStore) + // before releasing the data source. This guarantees they join the leader's + // inflight request rather than creating their own. + waitForFollowerCount(t, r, int32(requestCount-1)) ds.Release() @@ -4823,9 +4847,6 @@ func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication_SharedData(t t.Fatalf("timeout waiting for leader data source load") } - startFollowers := make(chan struct{}) - followersEntered := make(chan struct{}, requestCount-1) - for i := 1; i < requestCount; i++ { go func(i int) { defer wg.Done() @@ -4838,23 +4859,16 @@ func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication_SharedData(t followerData.Store(i, data) }, ) - <-startFollowers - followersEntered <- struct{}{} buf := &bytes.Buffer{} info, err := r.ArenaResolveGraphQLResponse(&ctx, response, buf) results[i] = result{info: info, output: buf.String(), err: err} }(i) } - close(startFollowers) - - for i := 1; i < requestCount; i++ { - select { - case <-followersEntered: - case <-time.After(time.Second): - t.Fatalf("timeout waiting for follower %d to start", i) - } - } + // Wait until all followers have entered the singleflight (called LoadOrStore) + // before releasing the data source. This guarantees they join the leader's + // inflight request rather than creating their own. + waitForFollowerCount(t, r, int32(requestCount-1)) ds.Release() @@ -6426,10 +6440,21 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { c, cancel := context.WithCancel(context.Background()) defer cancel() + // sub2Ready gates the data source goroutine so that it doesn't start + // emitting before sub2 has been registered on the trigger. Without this, + // the emitting goroutine's first triggerUpdate can race sub2's + // addSubscription on the unbuffered events channel, causing sub2 to + // miss counter=0. + sub2Ready := make(chan struct{}) fakeStream := createFakeStream(func(counter int) (message string, done bool) { return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), counter == 100 }, 1*time.Millisecond, func(input []byte) { assert.Equal(t, `{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`, string(input)) + // Block the data source goroutine until sub2 is registered. + // onStart runs inside the goroutine that calls Start(), not the + // event loop, so blocking here is safe — the event loop remains + // free to process sub2's addSubscription event. + <-sub2Ready }, func(ctx StartupHookContext, input []byte) (err error) { return nil }) @@ -6451,6 +6476,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { err2 := resolver1.AsyncResolveGraphQLSubscription(ctx2, plan1, recorder2, id2) assert.NoError(t, err2) + close(sub2Ready) // complete is called only on the last recorder recorder1.AwaitComplete(t, defaultTimeout)