diff --git a/go/pools/smartconnpool/connection.go b/go/pools/smartconnpool/connection.go index cdb5720596e..dbc235a8218 100644 --- a/go/pools/smartconnpool/connection.go +++ b/go/pools/smartconnpool/connection.go @@ -19,7 +19,6 @@ package smartconnpool import ( "context" "sync/atomic" - "time" ) type Connection interface { @@ -33,8 +32,8 @@ type Connection interface { type Pooled[C Connection] struct { next atomic.Pointer[Pooled[C]] - timeCreated time.Time - timeUsed time.Time + timeCreated timestamp + timeUsed timestamp pool *ConnPool[C] Conn C diff --git a/go/pools/smartconnpool/pool.go b/go/pools/smartconnpool/pool.go index c75ad6c12df..ef69e9cafda 100644 --- a/go/pools/smartconnpool/pool.go +++ b/go/pools/smartconnpool/pool.go @@ -18,7 +18,6 @@ package smartconnpool import ( "context" - "slices" "sync" "sync/atomic" "time" @@ -134,7 +133,6 @@ type ConnPool[C Connection] struct { connect Connector[C] // refresh is the callback to check whether the pool needs to be refreshed refresh RefreshCheck - // maxCapacity is the maximum value to which capacity can be set; when the pool // is re-opened, it defaults to this capacity maxCapacity int64 @@ -381,19 +379,23 @@ func (pool *ConnPool[C]) put(conn *Pooled[C]) { if conn == nil { var err error + // Using context.Background() is fine since MySQL connection already enforces + // a connect timeout via the `db_connect_timeout_ms` config param. 
conn, err = pool.connNew(context.Background()) if err != nil { pool.closedConn() return } } else { - conn.timeUsed = time.Now() + conn.timeUsed.update() lifetime := pool.extendedMaxLifetime() - if lifetime > 0 && time.Until(conn.timeCreated.Add(lifetime)) < 0 { + if lifetime > 0 && conn.timeCreated.elapsed() > lifetime { pool.Metrics.maxLifetimeClosed.Add(1) conn.Close() - if err := pool.connReopen(context.Background(), conn, conn.timeUsed); err != nil { + // Using context.Background() is fine since MySQL connection already enforces + // a connect timeout via the `db_connect_timeout_ms` config param. + if err := pool.connReopen(context.Background(), conn, conn.timeUsed.get()); err != nil { pool.closedConn() return } @@ -418,12 +420,30 @@ func (pool *ConnPool[C]) tryReturnConn(conn *Pooled[C]) bool { return false } +func (pool *ConnPool[C]) pop(stack *connStack[C]) *Pooled[C] { + // retry-loop: pop a connection from the stack and atomically check whether + // its timeout has elapsed. If the timeout has elapsed, the borrow will fail, + // which means that a background worker has already marked this connection + // as stale and is in the process of shutting it down. 
If we successfully mark + // the timeout as borrowed, we know that background workers will not be able + // to expire this connection (even if it's still visible to them), so it's + // safe to return it + for conn, ok := stack.Pop(); ok; conn, ok = stack.Pop() { + if conn.timeUsed.borrow() { + return conn + } + } + return nil +} + func (pool *ConnPool[C]) tryReturnAnyConn() bool { - if conn, ok := pool.clean.Pop(); ok { + if conn := pool.pop(&pool.clean); conn != nil { + conn.timeUsed.update() return pool.tryReturnConn(conn) } for u := 0; u <= stackMask; u++ { - if conn, ok := pool.settings[u].Pop(); ok { + if conn := pool.pop(&pool.settings[u]); conn != nil { + conn.timeUsed.update() return pool.tryReturnConn(conn) } } @@ -439,15 +459,22 @@ func (pool *ConnPool[D]) extendedMaxLifetime() time.Duration { return time.Duration(maxLifetime) + time.Duration(extended) } -func (pool *ConnPool[C]) connReopen(ctx context.Context, dbconn *Pooled[C], now time.Time) error { - var err error +func (pool *ConnPool[C]) connReopen(ctx context.Context, dbconn *Pooled[C], now time.Duration) (err error) { dbconn.Conn, err = pool.config.connect(ctx) if err != nil { return err } - dbconn.timeUsed = now - dbconn.timeCreated = now + if setting := dbconn.Conn.Setting(); setting != nil { + err = dbconn.Conn.ApplySetting(ctx, setting) + if err != nil { + dbconn.Close() + return err + } + } + + dbconn.timeCreated.set(now) + dbconn.timeUsed.set(now) return nil } @@ -456,13 +483,14 @@ func (pool *ConnPool[C]) connNew(ctx context.Context) (*Pooled[C], error) { if err != nil { return nil, err } - now := time.Now() - return &Pooled[C]{ - timeCreated: now, - timeUsed: now, - pool: pool, - Conn: conn, - }, nil + pooled := &Pooled[C]{ + pool: pool, + Conn: conn, + } + now := monotonicNow() + pooled.timeUsed.set(now) + pooled.timeCreated.set(now) + return pooled, nil } func (pool *ConnPool[C]) getFromSettingsStack(setting *Setting) *Pooled[C] { @@ -475,7 +503,7 @@ func (pool *ConnPool[C]) 
getFromSettingsStack(setting *Setting) *Pooled[C] { for i := uint32(0); i <= stackMask; i++ { pos := (i + start) & stackMask - if conn, ok := pool.settings[pos].Pop(); ok { + if conn := pool.pop(&pool.settings[pos]); conn != nil { return conn } } @@ -509,7 +537,7 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) { pool.Metrics.getCount.Add(1) // best case: if there's a connection in the clean stack, return it right away - if conn, ok := pool.clean.Pop(); ok { + if conn := pool.pop(&pool.clean); conn != nil { pool.borrowed.Add(1) return conn, nil } @@ -545,7 +573,7 @@ func (pool *ConnPool[C]) get(ctx context.Context) (*Pooled[C], error) { err = conn.Conn.ResetSetting(ctx) if err != nil { conn.Close() - err = pool.connReopen(ctx, conn, time.Now()) + err = pool.connReopen(ctx, conn, monotonicNow()) if err != nil { pool.closedConn() return nil, err @@ -563,10 +591,10 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) ( var err error // best case: check if there's a connection in the setting stack where our Setting belongs - conn, _ := pool.settings[setting.bucket&stackMask].Pop() + conn := pool.pop(&pool.settings[setting.bucket&stackMask]) // if there's connection with our setting, try popping a clean connection if conn == nil { - conn, _ = pool.clean.Pop() + conn = pool.pop(&pool.clean) } // otherwise try opening a brand new connection and we'll apply the setting to it if conn == nil { @@ -605,7 +633,7 @@ func (pool *ConnPool[C]) getWithSetting(ctx context.Context, setting *Setting) ( err = conn.Conn.ResetSetting(ctx) if err != nil { conn.Close() - err = pool.connReopen(ctx, conn, time.Now()) + err = pool.connReopen(ctx, conn, monotonicNow()) if err != nil { pool.closedConn() return nil, err @@ -667,7 +695,7 @@ func (pool *ConnPool[C]) setCapacity(ctx context.Context, newcap int64) error { // try closing from connections which are currently idle in the stacks conn := pool.getFromSettingsStack(nil) if conn == nil { 
- conn, _ = pool.clean.Pop() + conn = pool.pop(&pool.clean) } if conn == nil { time.Sleep(delay) continue } @@ -689,21 +717,26 @@ func (pool *ConnPool[C]) closeIdleResources(now time.Time) { return } - var conns []*Pooled[C] + mono := monotonicFromTime(now) closeInStack := func(s *connStack[C]) { - conns = s.PopAll(conns[:0]) - slices.Reverse(conns) - - for _, conn := range conns { - if conn.timeUsed.Add(timeout).Sub(now) < 0 { + // Do a read-only best effort iteration of all the connections in this + // stack and atomically attempt to mark them as expired. + // Any connections that are marked as expired are _not_ removed from + // the stack; it's generally unsafe to remove nodes from the stack + // besides the head. When clients pop from the stack, they'll immediately + // notice the expired connection and ignore it. + // see: timestamp.expired + for conn := s.Peek(); conn != nil; conn = conn.next.Load() { + if conn.timeUsed.expired(mono, timeout) { pool.Metrics.idleClosed.Add(1) conn.Close() - pool.closedConn() - continue + // Using context.Background() is fine since MySQL connection already enforces + // a connect timeout via the `db_connect_timeout_ms` config param. 
+ if err := pool.connReopen(context.Background(), conn, mono); err != nil { + pool.closedConn() + } } - - s.Push(conn) } } diff --git a/go/pools/smartconnpool/pool_test.go b/go/pools/smartconnpool/pool_test.go index a399bdfb3a4..cf0b18de252 100644 --- a/go/pools/smartconnpool/pool_test.go +++ b/go/pools/smartconnpool/pool_test.go @@ -36,12 +36,11 @@ var ( type TestState struct { lastID, open, close, reset atomic.Int64 - waits []time.Time mu sync.Mutex - - chaos struct { + waits []time.Time + chaos struct { delayConnect time.Duration - failConnect bool + failConnect atomic.Bool failApply bool } } @@ -109,7 +108,7 @@ func newConnector(state *TestState) Connector[*TestConn] { if state.chaos.delayConnect != 0 { time.Sleep(state.chaos.delayConnect) } - if state.chaos.failConnect { + if state.chaos.failConnect.Load() { return nil, fmt.Errorf("failed to connect: forced failure") } return &TestConn{ @@ -586,6 +585,45 @@ func TestUserClosing(t *testing.T) { } } +func TestConnReopen(t *testing.T) { + var state TestState + + p := NewPool(&Config[*TestConn]{ + Capacity: 1, + IdleTimeout: 200 * time.Millisecond, + MaxLifetime: 10 * time.Millisecond, + LogWait: state.LogWait, + }).Open(newConnector(&state), nil) + + defer p.Close() + + conn, err := p.Get(context.Background(), nil) + require.NoError(t, err) + assert.EqualValues(t, 1, state.lastID.Load()) + assert.EqualValues(t, 1, p.Active()) + + // wait enough to reach maxlifetime. + time.Sleep(50 * time.Millisecond) + + p.put(conn) + assert.EqualValues(t, 2, state.lastID.Load()) + assert.EqualValues(t, 1, p.Active()) + + // wait enough to reach idle timeout. + time.Sleep(300 * time.Millisecond) + assert.GreaterOrEqual(t, state.lastID.Load(), int64(3)) + assert.EqualValues(t, 1, p.Active()) + assert.GreaterOrEqual(t, p.Metrics.IdleClosed(), int64(1)) + + // mark connect to fail + state.chaos.failConnect.Store(true) + // wait enough to reach idle timeout and connect to fail. 
+ time.Sleep(300 * time.Millisecond) + // no active connection should be left. + assert.Zero(t, p.Active()) + +} + func TestIdleTimeout(t *testing.T) { testTimeout := func(t *testing.T, setting *Setting) { var state TestState @@ -608,6 +646,7 @@ func TestIdleTimeout(t *testing.T) { conns = append(conns, r) } + assert.GreaterOrEqual(t, state.open.Load(), int64(5)) // wait a long while; ensure that none of the conns have been closed time.Sleep(1 * time.Second) @@ -619,12 +658,24 @@ func TestIdleTimeout(t *testing.T) { p.put(conn) } + time.Sleep(1 * time.Second) + for _, closed := range closers { - <-closed + select { + case <-closed: + default: + t.Fatalf("Connections remain open after 1 second") + } } + // At least 5 connections should have been closed by now. + assert.GreaterOrEqual(t, p.Metrics.IdleClosed(), int64(5), "At least 5 connections should have been closed by now.") + + // At any point, at least 4 connections should be open, with 1 either in the process of opening or already opened. + // The idle connection closer shuts down one connection at a time. + assert.GreaterOrEqual(t, state.open.Load(), int64(4)) - // no need to assert anything: all the connections in the pool should are idle-closed - // now and if they're not the test will timeout and fail + // The number of available connections in the pool should remain at 5. + assert.EqualValues(t, 5, p.Available()) } t.Run("WithoutSettings", func(t *testing.T) { testTimeout(t, nil) }) @@ -650,7 +701,7 @@ func TestIdleTimeoutCreateFail(t *testing.T) { // Change the factory before putting back // to prevent race with the idle closer, who will // try to use it. - state.chaos.failConnect = true + state.chaos.failConnect.Store(true) p.put(r) timeout := time.After(1 * time.Second) for p.Active() != 0 { @@ -661,7 +712,7 @@ func TestIdleTimeoutCreateFail(t *testing.T) { } } // reset factory for next run. 
- state.chaos.failConnect = false + state.chaos.failConnect.Store(false) } } @@ -752,7 +803,7 @@ func TestExtendedLifetimeTimeout(t *testing.T) { func TestCreateFail(t *testing.T) { var state TestState - state.chaos.failConnect = true + state.chaos.failConnect.Store(true) ctx := context.Background() p := NewPool(&Config[*TestConn]{ @@ -799,12 +850,12 @@ func TestCreateFailOnPut(t *testing.T) { require.NoError(t, err) // change factory to fail the put. - state.chaos.failConnect = true + state.chaos.failConnect.Store(true) p.put(nil) assert.Zero(t, p.Active()) // change back for next iteration. - state.chaos.failConnect = false + state.chaos.failConnect.Store(false) } } @@ -822,7 +873,7 @@ func TestSlowCreateFail(t *testing.T) { LogWait: state.LogWait, }).Open(newConnector(&state), nil) - state.chaos.failConnect = true + state.chaos.failConnect.Store(true) for i := 0; i < 3; i++ { go func() { @@ -841,7 +892,7 @@ func TestSlowCreateFail(t *testing.T) { default: } - state.chaos.failConnect = false + state.chaos.failConnect.Store(false) conn, err := p.Get(ctx, setting) require.NoError(t, err) diff --git a/go/pools/smartconnpool/stack.go b/go/pools/smartconnpool/stack.go index ba38e31cecf..af86fa354bd 100644 --- a/go/pools/smartconnpool/stack.go +++ b/go/pools/smartconnpool/stack.go @@ -25,6 +25,9 @@ import ( // connStack is a lock-free stack for Connection objects. It is safe to // use from several goroutines. type connStack[C Connection] struct { + // top is a pointer to the top node on the stack and to an increasing + // counter of pop operations, to prevent A-B-A races. 
+ // See: https://en.wikipedia.org/wiki/ABA_problem top atomic2.PointerAndUint64[Pooled[C]] } @@ -55,24 +58,7 @@ func (s *connStack[C]) Pop() (*Pooled[C], bool) { } } -func (s *connStack[C]) PopAll(out []*Pooled[C]) []*Pooled[C] { - var oldHead *Pooled[C] - - for { - var popCount uint64 - oldHead, popCount = s.top.Load() - if oldHead == nil { - return out - } - if s.top.CompareAndSwap(oldHead, popCount, nil, popCount+1) { - break - } - runtime.Gosched() - } - - for oldHead != nil { - out = append(out, oldHead) - oldHead = oldHead.next.Load() - } - return out +func (s *connStack[C]) Peek() *Pooled[C] { + top, _ := s.top.Load() + return top } diff --git a/go/pools/smartconnpool/timestamp.go b/go/pools/smartconnpool/timestamp.go new file mode 100644 index 00000000000..961ff18a5c5 --- /dev/null +++ b/go/pools/smartconnpool/timestamp.go @@ -0,0 +1,94 @@ +package smartconnpool + +import ( + "math" + "sync/atomic" + "time" +) + +var monotonicRoot = time.Now() + +// timestamp is a monotonic point in time, stored as a number of +// nanoseconds since the monotonic root. This type is only 8 bytes +// and hence can always be accessed atomically +type timestamp struct { + nano atomic.Int64 +} + +// timestampExpired is a special value that means this timestamp is now past +// an arbitrary expiration point, and hence doesn't need to store an actual point in time +const timestampExpired = math.MaxInt64 + +// timestampBusy is a special value that means this timestamp no longer +// tracks an expiration point +const timestampBusy = math.MinInt64 + +// monotonicNow returns the current monotonic time as a time.Duration. +// This is a very efficient operation because time.Since performs direct +// subtraction of monotonic times without considering the wall clock times. +func monotonicNow() time.Duration { + return time.Since(monotonicRoot) +} + +// monotonicFromTime converts a wall-clock time from time.Now into a +// monotonic timestamp. 
+// This is a very efficient operation because time.(*Time).Sub performs direct +// subtraction of monotonic times without considering the wall clock times. +func monotonicFromTime(now time.Time) time.Duration { + return now.Sub(monotonicRoot) +} + +// set sets this timestamp to the given monotonic value +func (t *timestamp) set(mono time.Duration) { + t.nano.Store(int64(mono)) +} + +// get returns the monotonic time of this timestamp as the number of nanoseconds +// since the monotonic root. +func (t *timestamp) get() time.Duration { + return time.Duration(t.nano.Load()) +} + +// elapsed returns the number of nanoseconds that have passed since +// this timestamp was updated +func (t *timestamp) elapsed() time.Duration { + return monotonicNow() - t.get() +} + +// update sets this timestamp's value to the current monotonic time +func (t *timestamp) update() { + t.nano.Store(int64(monotonicNow())) +} + +// borrow attempts to borrow this timestamp atomically. +// It only succeeds if we can ensure that nobody else has marked +// this timestamp as expired. When successful, the timestamp +// is marked as "busy" as it no longer tracks an expiration point. +func (t *timestamp) borrow() bool { + stamp := t.nano.Load() + switch stamp { + case timestampExpired: + return false + case timestampBusy: + panic("timestampBusy when borrowing a time") + default: + return t.nano.CompareAndSwap(stamp, timestampBusy) + } +} + +// expired attempts to atomically expire this timestamp. +// It only succeeds if we can ensure the timestamp hasn't been +// concurrently expired or borrowed. +func (t *timestamp) expired(now time.Duration, timeout time.Duration) bool { + stamp := t.nano.Load() + if stamp == timestampExpired { + return false + } + if stamp == timestampBusy { + return false + } + if now-time.Duration(stamp) > timeout { + return t.nano.CompareAndSwap(stamp, timestampExpired) + } + return false +}