From a4ab0b0ecf5a4083e030c5eac949d9e08323d5e1 Mon Sep 17 00:00:00 2001 From: Arthur Schreiber Date: Mon, 14 Jul 2025 10:32:32 +0000 Subject: [PATCH] Fix race conditions in the connection pool code. Signed-off-by: Arthur Schreiber --- go/pools/smartconnpool/pool.go | 52 ++++++++++++++++++------- go/pools/smartconnpool/waitlist.go | 8 +++- go/pools/smartconnpool/waitlist_test.go | 6 ++- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go index 3914bda943b..f6d6d73d091 100644 --- a/go/pools/smartconnpool/pool.go +++ b/go/pools/smartconnpool/pool.go @@ -129,7 +129,7 @@ type ConnPool[C Connection] struct { // workers is a waitgroup for all the currently running worker goroutines workers sync.WaitGroup close chan struct{} - capacityMu sync.Mutex + capacityMu sync.RWMutex config struct { // connect is the callback to create a new connection for the pool @@ -430,18 +430,40 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool { if pool.wait.tryReturnConn(conn) { return true } - if pool.closeOnIdleLimitReached(conn) { + + for { + if pool.capacity.Load() < pool.active.Load() { + conn.Close() + pool.closedConn() + return true + } + + if pool.closeOnIdleLimitReached(conn) { + return false + } + + if !pool.capacityMu.TryRLock() { + // If we can't get a read lock here, it means that the pool is being closed. Retry and check `capacity` again. + continue + } + defer pool.capacityMu.RUnlock() + + if pool.capacity.Load() < pool.active.Load() { + conn.Close() + pool.closedConn() + return true + } + + connSetting := conn.Conn.Setting() + if connSetting == nil { + pool.clean.Push(conn) + } else { + stack := connSetting.bucket & stackMask + pool.settings[stack].Push(conn) + pool.freshSettingsStack.Store(int64(stack)) + } return false } - connSetting := conn.Conn.Setting() - if connSetting == nil { - pool.clean.Push(conn) - } else { - stack := connSetting.bucket & stackMask - pool.settings[stack].Push(conn) - pool.freshSettingsStack.Store(int64(stack)) - } - return false } func (pool *ConnPool[C]) pop(stack *connStack[C]) *Pooled[C] { @@ -595,7 +617,9 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) { // to other clients, wait until one of the connections is returned if conn == nil { start := time.Now() - conn, err = pool.wait.waitForConn(ctx, nil) + conn, err = pool.wait.waitForConn(ctx, nil, func() bool { + return pool.close == nil || pool.capacity.Load() == 0 + }) if err != nil { return nil, ErrTimeout } @@ -652,7 +676,9 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) ( // wait for one of them if conn == nil { start := time.Now() - conn, err = pool.wait.waitForConn(ctx, setting) + conn, err = pool.wait.waitForConn(ctx, setting, func() bool { + return pool.close == nil || pool.capacity.Load() == 0 + }) if err != nil { return nil, ErrTimeout } diff --git a/go/pools/smartconnpool/waitlist.go b/go/pools/smartconnpool/waitlist.go index ef1eb1fe997..9ee19a246ed 100644 --- a/go/pools/smartconnpool/waitlist.go +++ b/go/pools/smartconnpool/waitlist.go @@ -50,11 +50,17 @@ type waitlist[C Connection] struct { // The returned connection may _not_ have the requested Setting. This function can // also return a `nil` connection even if our context has expired, if the pool has // forced an expiration of all waiters in the waitlist. -func (wl *waitlist[C]) waitForConn(ctx context.Context, setting *Setting) (*Pooled[C], error) { +func (wl *waitlist[C]) waitForConn(ctx context.Context, setting *Setting, isClosed func() bool) (*Pooled[C], error) { elem := wl.nodes.Get().(*list.Element[waiter[C]]) elem.Value = waiter[C]{setting: setting, conn: nil, ctx: ctx} wl.mu.Lock() + if isClosed() { + // if the pool is closed, we can't wait for a connection, so return an error + wl.nodes.Put(elem) + wl.mu.Unlock() + return nil, ErrConnPoolClosed + } // add ourselves as a waiter at the end of the waitlist wl.list.PushBackValue(elem) wl.mu.Unlock() diff --git a/go/pools/smartconnpool/waitlist_test.go b/go/pools/smartconnpool/waitlist_test.go index 1486aa989b6..49b5af518d3 100644 --- a/go/pools/smartconnpool/waitlist_test.go +++ b/go/pools/smartconnpool/waitlist_test.go @@ -38,7 +38,11 @@ func TestWaitlistExpireWithMultipleWaiters(t *testing.T) { for i := 0; i < waiterCount; i++ { go func() { - _, err := wait.waitForConn(ctx, nil) + _, err := wait.waitForConn(ctx, nil, func() bool { + // This function is called to check if the pool is closed. + return ctx.Err() != nil + }) + if err != nil { expireCount.Add(1) }