diff --git a/context.go b/context.go new file mode 100644 index 0000000..e19d2a6 --- /dev/null +++ b/context.go @@ -0,0 +1,24 @@ +package puddle + +import ( + "context" + "time" +) + +// valueCancelCtx combines two contexts into one. One context is used for values and the other is used for cancellation. +type valueCancelCtx struct { + valueCtx context.Context + cancelCtx context.Context +} + +func (ctx *valueCancelCtx) Deadline() (time.Time, bool) { return ctx.cancelCtx.Deadline() } +func (ctx *valueCancelCtx) Done() <-chan struct{} { return ctx.cancelCtx.Done() } +func (ctx *valueCancelCtx) Err() error { return ctx.cancelCtx.Err() } +func (ctx *valueCancelCtx) Value(key any) any { return ctx.valueCtx.Value(key) } + +func newValueCancelCtx(valueCtx, cancelContext context.Context) context.Context { + return &valueCancelCtx{ + valueCtx: valueCtx, + cancelCtx: cancelContext, + } +} diff --git a/export_test.go b/export_test.go new file mode 100644 index 0000000..36e8df6 --- /dev/null +++ b/export_test.go @@ -0,0 +1,9 @@ +package puddle + +import "context" + +func (p *Pool[T]) AcquireRaw(ctx context.Context) (*Resource[T], error) { + return p.acquire(ctx) +} + +var AcquireSemAll = acquireSemAll diff --git a/go.mod b/go.mod index b9efbe2..39e64bd 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/jackc/puddle/v2 -go 1.18 +go 1.19 require github.com/stretchr/testify v1.8.0 require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.4.0 // indirect + golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f1ea297..4cea6fc 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,16 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7 h1:ZrnxWX62AgTKOSagEqxvb3ffipvEDX2pl7E1TdqLqIc= +golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/genstack/gen_stack.go b/internal/genstack/gen_stack.go new file mode 100644 index 0000000..7e4660c --- /dev/null +++ b/internal/genstack/gen_stack.go @@ -0,0 +1,85 @@ +package genstack + +// GenStack implements a generational stack. +// +// GenStack works as common stack except for the fact that all elements in the +// older generation are guaranteed to be popped before any element in the newer +// generation. New elements are always pushed to the current (newest) +// generation. +// +// We could also say that GenStack behaves as a stack in case of a single +// generation, but it behaves as a queue of individual generation stacks. +type GenStack[T any] struct { + // We can represent arbitrary number of generations using 2 stacks. The + // new stack stores all new pushes and the old stack serves all reads. + // Old stack can represent multiple generations. If old == new, then all + // elements pushed in previous (not current) generations have already + // been popped. + + old *stack[T] + new *stack[T] +} + +// NewGenStack creates a new empty GenStack. +func NewGenStack[T any]() *GenStack[T] { + s := &stack[T]{} + return &GenStack[T]{ + old: s, + new: s, + } +} + +func (s *GenStack[T]) Pop() (T, bool) { + // Pushes always append to the new stack, so if the old once becomes + // empty, it will remail empty forever. + if s.old.len() == 0 && s.old != s.new { + s.old = s.new + } + + if s.old.len() == 0 { + var zero T + return zero, false + } + + return s.old.pop(), true +} + +// Push pushes a new element at the top of the stack. +func (s *GenStack[T]) Push(v T) { s.new.push(v) } + +// NextGen starts a new stack generation. +func (s *GenStack[T]) NextGen() { + if s.old == s.new { + s.new = &stack[T]{} + return + } + + // We need to pop from the old stack to the top of the new stack. Let's + // have an example: + // + // Old: 4 3 2 1 + // New: 8 7 6 5 + // PopOrder: 1 2 3 4 5 6 7 8 + // + // + // To preserve pop order, we have to take all elements from the old + // stack and push them to the top of new stack: + // + // New: 8 7 6 5 4 3 2 1 + // + s.new.push(s.old.takeAll()...) + + // We have the old stack allocated and empty, so why not to reuse it as + // new new stack. + s.old, s.new = s.new, s.old +} + +// Len returns number of elements in the stack. 
+func (s *GenStack[T]) Len() int { + l := s.old.len() + if s.old != s.new { + l += s.new.len() + } + + return l +} diff --git a/internal/genstack/gen_stack_test.go b/internal/genstack/gen_stack_test.go new file mode 100644 index 0000000..519bd3b --- /dev/null +++ b/internal/genstack/gen_stack_test.go @@ -0,0 +1,90 @@ +package genstack + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func requirePopEmpty[T any](t testing.TB, s *GenStack[T]) { + v, ok := s.Pop() + require.False(t, ok) + require.Zero(t, v) +} + +func requirePop[T any](t testing.TB, s *GenStack[T], expected T) { + v, ok := s.Pop() + require.True(t, ok) + require.Equal(t, expected, v) +} + +func TestGenStack_Empty(t *testing.T) { + s := NewGenStack[int]() + requirePopEmpty(t, s) +} + +func TestGenStack_SingleGen(t *testing.T) { + r := require.New(t) + s := NewGenStack[int]() + + s.Push(1) + s.Push(2) + r.Equal(2, s.Len()) + + requirePop(t, s, 2) + requirePop(t, s, 1) + requirePopEmpty(t, s) +} + +func TestGenStack_TwoGen(t *testing.T) { + r := require.New(t) + s := NewGenStack[int]() + + s.Push(3) + s.Push(4) + s.Push(5) + r.Equal(3, s.Len()) + s.NextGen() + r.Equal(3, s.Len()) + s.Push(6) + s.Push(7) + r.Equal(5, s.Len()) + + requirePop(t, s, 5) + requirePop(t, s, 4) + requirePop(t, s, 3) + requirePop(t, s, 7) + requirePop(t, s, 6) + requirePopEmpty(t, s) +} + +func TestGenStack_MuptiGen(t *testing.T) { + r := require.New(t) + s := NewGenStack[int]() + + s.Push(10) + s.Push(11) + s.Push(12) + r.Equal(3, s.Len()) + s.NextGen() + r.Equal(3, s.Len()) + s.Push(13) + s.Push(14) + r.Equal(5, s.Len()) + s.NextGen() + r.Equal(5, s.Len()) + s.Push(15) + s.Push(16) + s.Push(17) + r.Equal(8, s.Len()) + + requirePop(t, s, 12) + requirePop(t, s, 11) + requirePop(t, s, 10) + requirePop(t, s, 14) + requirePop(t, s, 13) + requirePop(t, s, 17) + requirePop(t, s, 16) + requirePop(t, s, 15) + requirePopEmpty(t, s) +} diff --git a/internal/genstack/stack.go b/internal/genstack/stack.go new file mode 100644 index 0000000..dbced0c --- /dev/null +++ b/internal/genstack/stack.go @@ -0,0 +1,39 @@ +package genstack + +// stack is a wrapper around an array implementing a stack. +// +// We cannot use slice to represent the stack because append might change the +// pointer value of the slice. That would be an issue in GenStack +// implementation. +type stack[T any] struct { + arr []T +} + +// push pushes a new element at the top of a stack. +func (s *stack[T]) push(vs ...T) { s.arr = append(s.arr, vs...) } + +// pop pops the stack top-most element. +// +// If stack length is zero, this method panics. +func (s *stack[T]) pop() T { + idx := s.len() - 1 + val := s.arr[idx] + + // Avoid memory leak + var zero T + s.arr[idx] = zero + + s.arr = s.arr[:idx] + return val +} + +// takeAll returns all elements in the stack in order as they are stored - i.e. +// the top-most stack element is the last one. +func (s *stack[T]) takeAll() []T { + arr := s.arr + s.arr = nil + return arr +} + +// len returns number of elements in the stack. 
+func (s *stack[T]) len() int { return len(s.arr) } diff --git a/internal_test.go b/internal_test.go deleted file mode 100644 index a061b73..0000000 --- a/internal_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package puddle - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRemoveResourcePanicsWithBugReportIfResourceDoesNotExist(t *testing.T) { - s := []*Resource[any]{new(Resource[any]), new(Resource[any]), new(Resource[any])} - assert.PanicsWithValue(t, "BUG: removeResource could not find res in slice", func() { removeResource(s, new(Resource[any])) }) -} diff --git a/log.go b/log.go new file mode 100644 index 0000000..b21b946 --- /dev/null +++ b/log.go @@ -0,0 +1,32 @@ +package puddle + +import "unsafe" + +type ints interface { + int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 +} + +// log2Int returns log2 of an integer. This function panics if val < 0. For val +// == 0, returns 0. +func log2Int[T ints](val T) uint8 { + if val <= 0 { + panic("log2 of non-positive number does not exist") + } + + return log2IntRange(val, 0, uint8(8*unsafe.Sizeof(val))) +} + +func log2IntRange[T ints](val T, begin, end uint8) uint8 { + length := end - begin + if length == 1 { + return begin + } + + delim := begin + length/2 + mask := T(1) << delim + if mask > val { + return log2IntRange(val, begin, delim) + } else { + return log2IntRange(val, delim, end) + } +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 0000000..d425313 --- /dev/null +++ b/log_test.go @@ -0,0 +1,49 @@ +package puddle + +import ( + "math" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestLog2Uint(t *testing.T) { + r := require.New(t) + + r.Equal(uint8(0), log2Int(1)) + r.Equal(uint8(0), log2Int[uint64](1)) + r.Equal(uint8(1), log2Int[uint32](2)) + r.Equal(uint8(7), log2Int[uint8](math.MaxUint8)) + r.Equal(uint8(15), log2Int[uint16](math.MaxUint16)) + r.Equal(uint8(31), log2Int[uint32](math.MaxUint32)) + r.Equal(uint8(63), log2Int[uint64](math.MaxUint64)) + + r.Panics(func() { log2Int[uint64](0) }) + r.Panics(func() { log2Int[int64](-1) }) +} + +func FuzzLog2Uint(f *testing.F) { + const cnt = 1000 + + rand := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < cnt; i++ { + val := uint64(rand.Int63()) + // val + 1 not to test val == 0. + f.Add(val + 1) + } + + f.Fuzz(func(t *testing.T, val uint64) { + var mx uint8 + for i := 63; i >= 0; i-- { + mask := uint64(1) << i + if mask&val != 0 { + mx = uint8(i) + break + } + } + + require.Equal(t, mx, log2Int(val)) + }) +} diff --git a/pool.go b/pool.go index 3708ebe..fef38d3 100644 --- a/pool.go +++ b/pool.go @@ -4,7 +4,11 @@ import ( "context" "errors" "sync" + "sync/atomic" "time" + + "github.com/jackc/puddle/v2/internal/genstack" + "golang.org/x/sync/semaphore" ) const ( @@ -112,11 +116,21 @@ func (res *Resource[T]) IdleDuration() time.Duration { // Pool is a concurrency-safe resource pool. type Pool[T any] struct { - cond *sync.Cond - destructWG *sync.WaitGroup - - allResources []*Resource[T] - idleResources []*Resource[T] + // mux is the pool internal lock. Any modification of shared state of + // the pool (but Acquires of acquireSem) must be performed only by + // holder of the lock. Long running operations are not allowed when mux + // is held. + mux sync.Mutex + // acquireSem provides an allowance to acquire a resource. + // + // Releases are allowed only when caller holds mux. 
Acquires have to + // happen before mux is locked (doesn't apply to semaphore.TryAcquire in + // AcquireAllIdle). + acquireSem *semaphore.Weighted + destructWG sync.WaitGroup + + allResources resList[T] + idleResources *genstack.GenStack[*Resource[T]] constructor Constructor[T] destructor Destructor[T] @@ -125,12 +139,12 @@ type Pool[T any] struct { acquireCount int64 acquireDuration time.Duration emptyAcquireCount int64 - canceledAcquireCount int64 + canceledAcquireCount atomic.Int64 resetCount int baseAcquireCtx context.Context - cancelBaseAcquireCtx func() + cancelBaseAcquireCtx context.CancelFunc closed bool } @@ -149,8 +163,8 @@ func NewPool[T any](config *Config[T]) (*Pool[T], error) { baseAcquireCtx, cancelBaseAcquireCtx := context.WithCancel(context.Background()) return &Pool[T]{ - cond: sync.NewCond(new(sync.Mutex)), - destructWG: &sync.WaitGroup{}, + acquireSem: semaphore.NewWeighted(int64(config.MaxSize)), + idleResources: genstack.NewGenStack[*Resource[T]](), maxSize: config.MaxSize, constructor: config.Constructor, destructor: config.Destructor, @@ -162,25 +176,21 @@ func NewPool[T any](config *Config[T]) (*Pool[T], error) { // Close destroys all resources in the pool and rejects future Acquire calls. // Blocks until all resources are returned to pool and destroyed. func (p *Pool[T]) Close() { - p.cond.L.Lock() + defer p.destructWG.Wait() + + p.mux.Lock() + defer p.mux.Unlock() + if p.closed { - p.cond.L.Unlock() return } p.closed = true p.cancelBaseAcquireCtx() - for _, res := range p.idleResources { - p.allResources = removeResource(p.allResources, res) + for res, ok := p.idleResources.Pop(); ok; res, ok = p.idleResources.Pop() { + p.allResources.remove(res) go p.destructResourceValue(res.value) } - p.idleResources = nil - p.cond.L.Unlock() - - // Wake up all go routines waiting for a resource to be returned so they can terminate. - p.cond.Broadcast() - - p.destructWG.Wait() } // Stat is a snapshot of Pool statistics. @@ -249,12 +259,14 @@ func (s *Stat) CanceledAcquireCount() int64 { // Stat returns the current pool statistics. func (p *Pool[T]) Stat() *Stat { - p.cond.L.Lock() + p.mux.Lock() + defer p.mux.Unlock() + s := &Stat{ maxResources: p.maxSize, acquireCount: p.acquireCount, emptyAcquireCount: p.emptyAcquireCount, - canceledAcquireCount: p.canceledAcquireCount, + canceledAcquireCount: p.canceledAcquireCount.Load(), acquireDuration: p.acquireDuration, } @@ -269,20 +281,42 @@ func (p *Pool[T]) Stat() *Stat { } } - p.cond.L.Unlock() return s } -// valueCancelCtx combines two contexts into one. One context is used for values and the other is used for cancellation. -type valueCancelCtx struct { - valueCtx context.Context - cancelCtx context.Context +// tryAcquireIdleResource checks if there is any idle resource. If there is +// some, this method removes it from idle list and returns it. If the idle pool +// is empty, this method returns nil and doesn't modify the idleResources slice. +// +// WARNING: Caller of this method must hold the pool mutex! 
+func (p *Pool[T]) tryAcquireIdleResource() *Resource[T] { + res, ok := p.idleResources.Pop() + if !ok { + return nil + } + + res.status = resourceStatusAcquired + return res } -func (ctx *valueCancelCtx) Deadline() (time.Time, bool) { return ctx.cancelCtx.Deadline() } -func (ctx *valueCancelCtx) Done() <-chan struct{} { return ctx.cancelCtx.Done() } -func (ctx *valueCancelCtx) Err() error { return ctx.cancelCtx.Err() } -func (ctx *valueCancelCtx) Value(key any) any { return ctx.valueCtx.Value(key) } +// createNewResource creates a new resource and inserts it into list of pool +// resources. +// +// WARNING: Caller of this method must hold the pool mutex! +func (p *Pool[T]) createNewResource() *Resource[T] { + res := &Resource[T]{ + pool: p, + creationTime: time.Now(), + lastUsedNano: nanotime(), + poolResetCount: p.resetCount, + status: resourceStatusConstructing, + } + + p.allResources.append(res) + p.destructWG.Add(1) + + return res +} // Acquire gets a resource from the pool. If no resources are available and the pool is not at maximum capacity it will // create a new resource. If the pool is at maximum capacity it will block until a resource is available. ctx can be @@ -292,131 +326,127 @@ func (ctx *valueCancelCtx) Value(key any) any { return ctx.valueCtx.Va // ctx. Canceling ctx will cause Acquire to return immediately but it will not cancel the resource creation. This avoids // the problem of it being impossible to create resources when the time to create a resource is greater than any one // caller of Acquire is willing to wait. -func (p *Pool[T]) Acquire(ctx context.Context) (*Resource[T], error) { - startNano := nanotime() - if doneChan := ctx.Done(); doneChan != nil { - select { - case <-ctx.Done(): - p.cond.L.Lock() - p.canceledAcquireCount += 1 - p.cond.L.Unlock() - return nil, ctx.Err() - default: - } +func (p *Pool[T]) Acquire(ctx context.Context) (_ *Resource[T], err error) { + select { + case <-ctx.Done(): + p.canceledAcquireCount.Add(1) + return nil, ctx.Err() + default: } - p.cond.L.Lock() + return p.acquire(ctx) +} - emptyAcquire := false +// acquire is a continuation of Acquire function that doesn't check context +// validity. +// +// This function exists solely only for benchmarking purposes. +func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { + startNano := nanotime() - for { - if p.closed { - p.cond.L.Unlock() - return nil, ErrClosedPool + var waitedForLock bool + if !p.acquireSem.TryAcquire(1) { + waitedForLock = true + err := p.acquireSem.Acquire(ctx, 1) + if err != nil { + p.canceledAcquireCount.Add(1) + return nil, err } + } - // If a resource is available now - if len(p.idleResources) > 0 { - res := p.idleResources[len(p.idleResources)-1] - p.idleResources[len(p.idleResources)-1] = nil // Avoid memory leak - p.idleResources = p.idleResources[:len(p.idleResources)-1] - res.status = resourceStatusAcquired - if emptyAcquire { - p.emptyAcquireCount += 1 - } - p.acquireCount += 1 - p.acquireDuration += time.Duration(nanotime() - startNano) - p.cond.L.Unlock() - return res, nil + p.mux.Lock() + if p.closed { + p.acquireSem.Release(1) + p.mux.Unlock() + return nil, ErrClosedPool + } + + // If a resource is available in the pool. + if res := p.tryAcquireIdleResource(); res != nil { + if waitedForLock { + p.emptyAcquireCount += 1 } + p.acquireCount += 1 + p.acquireDuration += time.Duration(nanotime() - startNano) + p.mux.Unlock() + return res, nil + } + + if len(p.allResources) >= int(p.maxSize) { + // Unreachable code. 
+ panic("bug: semaphore allowed more acquires than pool allows") + } - emptyAcquire = true - - // If there is room to create a resource do so - if len(p.allResources) < int(p.maxSize) { - res := &Resource[T]{pool: p, creationTime: time.Now(), lastUsedNano: nanotime(), poolResetCount: p.resetCount, status: resourceStatusConstructing} - p.allResources = append(p.allResources, res) - p.destructWG.Add(1) - p.cond.L.Unlock() - - // Create the resource in a goroutine to immediately return from Acquire if ctx is canceled without also canceling - // the constructor. See: https://github.com/jackc/pgx/issues/1287 and https://github.com/jackc/pgx/issues/1259 - constructErrCh := make(chan error) - go func() { - constructorCtx := &valueCancelCtx{valueCtx: ctx, cancelCtx: p.baseAcquireCtx} - value, err := p.constructResourceValue(constructorCtx) - p.cond.L.Lock() - if err != nil { - p.allResources = removeResource(p.allResources, res) - p.destructWG.Done() - - constructErrCh <- err - - p.cond.L.Unlock() - p.cond.Signal() - return - } - - res.value = value - res.status = resourceStatusAcquired - - select { - case constructErrCh <- nil: - p.emptyAcquireCount += 1 - p.acquireCount += 1 - p.acquireDuration += time.Duration(nanotime() - startNano) - p.cond.L.Unlock() - // No need to call Signal as this new resource was immediately acquired and did not change availability for - // any waiting Acquire calls. - case <-ctx.Done(): - p.cond.L.Unlock() - p.releaseAcquiredResource(res, res.lastUsedNano) - } - }() + // The resource is not idle, but there is enough space to create one. + res := p.createNewResource() + p.mux.Unlock() + + res, err := p.initResourceValue(ctx, res) + if err != nil { + return nil, err + } + + p.mux.Lock() + defer p.mux.Unlock() + + p.emptyAcquireCount += 1 + p.acquireCount += 1 + p.acquireDuration += time.Duration(nanotime() - startNano) + + return res, nil +} + +func (p *Pool[T]) initResourceValue(ctx context.Context, res *Resource[T]) (*Resource[T], error) { + // Create the resource in a goroutine to immediately return from Acquire + // if ctx is canceled without also canceling the constructor. + // + // See: + // - https://github.com/jackc/pgx/issues/1287 + // - https://github.com/jackc/pgx/issues/1259 + constructErrChan := make(chan error) + go func() { + constructorCtx := newValueCancelCtx(ctx, p.baseAcquireCtx) + value, err := p.constructor(constructorCtx) + if err != nil { + p.mux.Lock() + p.allResources.remove(res) + p.destructWG.Done() + + // The resource won't be acquired because its + // construction failed. We have to allow someone else to + // take that resouce. + p.acquireSem.Release(1) + p.mux.Unlock() select { + case constructErrChan <- err: case <-ctx.Done(): - p.cond.L.Lock() - p.canceledAcquireCount += 1 - p.cond.L.Unlock() - return nil, ctx.Err() - case err := <-constructErrCh: - if err != nil { - return nil, err - } - // we don't call signal here because we didn't change the resource pools - // at all so waking anything else up won't help - return res, nil + // The caller is cancelled, so no-one awaits the + // error. This branch avoid goroutine leak. } + return } - if ctx.Done() == nil { - p.cond.Wait() - } else { - // Convert p.cond.Wait into a channel - waitChan := make(chan struct{}, 1) - go func() { - p.cond.Wait() - waitChan <- struct{}{} - }() + res.value = value + res.status = resourceStatusAcquired - select { - case <-ctx.Done(): - // Allow goroutine waiting for signal to exit. Re-signal since we couldn't - // do anything with it. 
Another goroutine might be waiting. - go func() { - <-waitChan - p.cond.L.Unlock() - p.cond.Signal() - }() - - p.cond.L.Lock() - p.canceledAcquireCount += 1 - p.cond.L.Unlock() - return nil, ctx.Err() - case <-waitChan: - } + // This select works because the channel is unbuffered. + select { + case constructErrChan <- nil: + case <-ctx.Done(): + p.releaseAcquiredResource(res, res.lastUsedNano) + } + }() + + select { + case <-ctx.Done(): + p.canceledAcquireCount.Add(1) + return nil, ctx.Err() + case err := <-constructErrChan: + if err != nil { + return nil, err } + return res, nil } } @@ -424,82 +454,149 @@ func (p *Pool[T]) Acquire(ctx context.Context) (*Resource[T], error) { // resources are available but the pool has room to grow, a resource will be created in the background. ctx is only // used to cancel the background creation. func (p *Pool[T]) TryAcquire(ctx context.Context) (*Resource[T], error) { - p.cond.L.Lock() - defer p.cond.L.Unlock() + if !p.acquireSem.TryAcquire(1) { + return nil, ErrNotAvailable + } + + p.mux.Lock() + defer p.mux.Unlock() if p.closed { + p.acquireSem.Release(1) return nil, ErrClosedPool } // If a resource is available now - if len(p.idleResources) > 0 { - res := p.idleResources[len(p.idleResources)-1] - p.idleResources[len(p.idleResources)-1] = nil // Avoid memory leak - p.idleResources = p.idleResources[:len(p.idleResources)-1] + if res := p.tryAcquireIdleResource(); res != nil { p.acquireCount += 1 - res.status = resourceStatusAcquired return res, nil } - if len(p.allResources) < int(p.maxSize) { - res := &Resource[T]{pool: p, creationTime: time.Now(), lastUsedNano: nanotime(), poolResetCount: p.resetCount, status: resourceStatusConstructing} - p.allResources = append(p.allResources, res) - p.destructWG.Add(1) + if len(p.allResources) >= int(p.maxSize) { + // Unreachable code. + panic("bug: semaphore allowed more acquires than pool allows") + } - go func() { - value, err := p.constructResourceValue(ctx) - defer p.cond.Signal() - p.cond.L.Lock() - defer p.cond.L.Unlock() + res := p.createNewResource() + go func() { + value, err := p.constructor(ctx) + + p.mux.Lock() + defer p.mux.Unlock() + // We have to create the resource and only then release the + // semaphore - For the time being there is no resource that + // someone could acquire. + defer p.acquireSem.Release(1) + + if err != nil { + p.allResources.remove(res) + p.destructWG.Done() + return + } - if err != nil { - p.allResources = removeResource(p.allResources, res) - p.destructWG.Done() - return - } + res.value = value + res.status = resourceStatusIdle + p.idleResources.Push(res) + }() + + return nil, ErrNotAvailable +} + +// acquireSemAll tries to acquire num free tokens from sem. This function is +// guaranteed to acquire at least the lowest number of tokens that has been +// available in the semaphore during runtime of this function. +// +// For the time being, semaphore doesn't allow to acquire all tokens atomically +// (see https://github.com/golang/sync/pull/19). We simulate this by trying all +// powers of 2 that are less or equal to num. +// +// For example, let's immagine we have 19 free tokens in the semaphore which in +// total has 24 tokens (i.e. the maxSize of the pool is 24 resources). Then if +// num is 24, the log2Uint(24) is 4 and we try to acquire 16, 8, 4, 2 and 1 +// tokens. Out of those, the acquire of 16, 2 and 1 tokens will succeed. +// +// Naturally, Acquires and Releases of the semaphore might take place +// concurrently. 
For this reason, it's not guaranteed that absolutely all free +// tokens in the semaphore will be acquired. But it's guaranteed that at least +// the minimal number of tokens that has been present over the whole process +// will be acquired. This is sufficient for the use-case we have in this +// package. +// +// TODO: Replace this with acquireSem.TryAcquireAll() if it gets to +// upstream. https://github.com/golang/sync/pull/19 +func acquireSemAll(sem *semaphore.Weighted, num int) int { + if sem.TryAcquire(int64(num)) { + return num + } - res.value = value - res.status = resourceStatusIdle - p.idleResources = append(p.idleResources, res) - }() + var acquired int + for i := int(log2Int(num)); i >= 0; i-- { + val := 1 << i + if sem.TryAcquire(int64(val)) { + acquired += val + } } - return nil, ErrNotAvailable + return acquired } -// AcquireAllIdle atomically acquires all currently idle resources. Its intended -// use is for health check and keep-alive functionality. It does not update pool +// AcquireAllIdle acquires all currently idle resources. Its intended use is for +// health check and keep-alive functionality. It does not update pool // statistics. func (p *Pool[T]) AcquireAllIdle() []*Resource[T] { - p.cond.L.Lock() + p.mux.Lock() + defer p.mux.Unlock() + if p.closed { - p.cond.L.Unlock() return nil } - for _, res := range p.idleResources { + numIdle := p.idleResources.Len() + if numIdle == 0 { + return nil + } + + // In acquireSemAll we use only TryAcquire and not Acquire. Because + // TryAcquire cannot block, the fact that we hold mutex locked and try + // to acquire semaphore cannot result in dead-lock. + // + // Because the mutex is locked, no parallel Release can run. This + // implies that the number of tokens can only decrease because some + // Acquire/TryAcquire call can consume the semaphore token. Consequently + // acquired is always less or equal to numIdle. Moreover if acquired < + // numIdle, then there are some parallel Acquire/TryAcquire calls that + // will take the remaining idle connections. + acquired := acquireSemAll(p.acquireSem, numIdle) + + idle := make([]*Resource[T], acquired) + for i := range idle { + res, _ := p.idleResources.Pop() res.status = resourceStatusAcquired + idle[i] = res } - resources := p.idleResources // Swap out current slice - p.idleResources = nil - p.cond.L.Unlock() - return resources + // We have to bump the generation to ensure that Acquire/TryAcquire + // calls running in parallel (those which caused acquired < numIdle) + // will consume old connections and not freshly released connections + // instead. + p.idleResources.NextGen() + + return idle } // CreateResource constructs a new resource without acquiring it. // It goes straight in the IdlePool. It does not check against maxSize. // It can be useful to maintain warm resources under little load. 
func (p *Pool[T]) CreateResource(ctx context.Context) error { - p.cond.L.Lock() + p.mux.Lock() if p.closed { - p.cond.L.Unlock() + p.mux.Unlock() return ErrClosedPool } p.destructWG.Add(1) p.cond.L.Unlock() - value, err := p.constructResourceValue(ctx) + value, err := p.constructor(ctx) if err != nil { p.destructWG.Done() return err @@ -514,16 +611,16 @@ func (p *Pool[T]) CreateResource(ctx context.Context) error { poolResetCount: p.resetCount, } - p.cond.L.Lock() + p.mux.Lock() + defer p.mux.Unlock() + // If closed while constructing resource then destroy it and return an error if p.closed { go p.destructResourceValue(res.value) - p.cond.L.Unlock() return ErrClosedPool } - p.allResources = append(p.allResources, res) - p.idleResources = append(p.idleResources, res) - p.cond.L.Unlock() + p.allResources.append(res) + p.idleResources.Push(res) return nil } @@ -534,71 +631,53 @@ func (p *Pool[T]) CreateResource(ctx context.Context) error { // It is safe to reset a pool while resources are checked out. Those resources will be destroyed when they are returned // to the pool. func (p *Pool[T]) Reset() { - p.cond.L.Lock() - defer p.cond.L.Unlock() + p.mux.Lock() + defer p.mux.Unlock() p.resetCount++ - for i := range p.idleResources { - p.allResources = removeResource(p.allResources, p.idleResources[i]) - go p.destructResourceValue(p.idleResources[i].value) - p.idleResources[i] = nil + for res, ok := p.idleResources.Pop(); ok; res, ok = p.idleResources.Pop() { + p.allResources.remove(res) + go p.destructResourceValue(res.value) } - p.idleResources = p.idleResources[0:0] } // releaseAcquiredResource returns res to the the pool. func (p *Pool[T]) releaseAcquiredResource(res *Resource[T], lastUsedNano int64) { - p.cond.L.Lock() + p.mux.Lock() + defer p.mux.Unlock() + defer p.acquireSem.Release(1) if p.closed || res.poolResetCount != p.resetCount { - p.allResources = removeResource(p.allResources, res) + p.allResources.remove(res) go p.destructResourceValue(res.value) } else { res.lastUsedNano = lastUsedNano res.status = resourceStatusIdle - p.idleResources = append(p.idleResources, res) + p.idleResources.Push(res) } - - p.cond.L.Unlock() - p.cond.Signal() } // Remove removes res from the pool and closes it. If res is not part of the // pool Remove will panic. 
func (p *Pool[T]) destroyAcquiredResource(res *Resource[T]) { p.destructResourceValue(res.value) - p.cond.L.Lock() - p.allResources = removeResource(p.allResources, res) - p.cond.L.Unlock() - p.cond.Signal() + + p.mux.Lock() + defer p.mux.Unlock() + defer p.acquireSem.Release(1) + + p.allResources.remove(res) } func (p *Pool[T]) hijackAcquiredResource(res *Resource[T]) { - p.cond.L.Lock() + p.mux.Lock() + defer p.mux.Unlock() + defer p.acquireSem.Release(1) - p.allResources = removeResource(p.allResources, res) + p.allResources.remove(res) res.status = resourceStatusHijacked p.destructWG.Done() // not responsible for destructing hijacked resources - - p.cond.L.Unlock() - p.cond.Signal() -} - -func removeResource[T any](slice []*Resource[T], res *Resource[T]) []*Resource[T] { - for i := range slice { - if slice[i] == res { - slice[i] = slice[len(slice)-1] - slice[len(slice)-1] = nil // Avoid memory leak - return slice[:len(slice)-1] - } - } - - panic("BUG: removeResource could not find res in slice") -} - -func (p *Pool[T]) constructResourceValue(ctx context.Context) (T, error) { - return p.constructor(ctx) } func (p *Pool[T]) destructResourceValue(value T) { diff --git a/pool_test.go b/pool_test.go index 8f2b77d..45ee822 100644 --- a/pool_test.go +++ b/pool_test.go @@ -17,6 +17,7 @@ import ( "github.com/jackc/puddle/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/semaphore" ) type Counter struct { @@ -27,18 +28,18 @@ type Counter struct { // Next increments the counter and returns the value func (c *Counter) Next() int { c.mutex.Lock() + defer c.mutex.Unlock() + c.n += 1 - n := c.n - c.mutex.Unlock() - return n + return c.n } // Value returns the counter func (c *Counter) Value() int { c.mutex.Lock() - n := c.n - c.mutex.Unlock() - return n + defer c.mutex.Unlock() + + return c.n } func createConstructor() (puddle.Constructor[int], *Counter) { @@ -894,6 +895,17 @@ func TestSignalIsSentWhenResourceFailedToCreate(t *testing.T) { wg.Wait() } +func stressTestDur(t testing.TB) time.Duration { + s := os.Getenv("STRESS_TEST_DURATION") + if s == "" { + s = "1s" + } + + dur, err := time.ParseDuration(s) + require.Nil(t, err) + return dur +} + func TestStress(t *testing.T) { constructor, _ := createConstructor() var destructorCalls Counter @@ -980,10 +992,10 @@ func TestStress(t *testing.T) { for i := 0; i < workerCount; i++ { wg.Add(1) go func() { + defer wg.Done() for { select { case <-finishChan: - wg.Done() return default: } @@ -993,18 +1005,116 @@ func TestStress(t *testing.T) { }() } - s := os.Getenv("STRESS_TEST_DURATION") - if s == "" { - s = "1s" - } - testDuration, err := time.ParseDuration(s) - require.Nil(t, err) - time.AfterFunc(testDuration, func() { close(finishChan) }) + time.AfterFunc(stressTestDur(t), func() { close(finishChan) }) wg.Wait() - pool.Close() } +func TestStress_AcquireAllIdle_TryAcquire(t *testing.T) { + r := require.New(t) + + pool := testPool[int32](t) + + var wg sync.WaitGroup + done := make(chan struct{}) + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + idleRes := pool.AcquireAllIdle() + r.Less(len(idleRes), 2) + for _, res := range idleRes { + res.Release() + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + res, err := pool.TryAcquire(context.Background()) + if err != nil { + r.Equal(puddle.ErrNotAvailable, err) + } else { + r.NotNil(res) + res.Release() + } + } + }() + + 
time.AfterFunc(stressTestDur(t), func() { close(done) }) + wg.Wait() +} + +func TestStress_AcquireAllIdle_Acquire(t *testing.T) { + r := require.New(t) + + pool := testPool[int32](t) + + var wg sync.WaitGroup + done := make(chan struct{}) + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + idleRes := pool.AcquireAllIdle() + r.Less(len(idleRes), 2) + for _, res := range idleRes { + r.NotNil(res) + res.Release() + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + for { + select { + case <-done: + return + default: + } + + res, err := pool.Acquire(context.Background()) + if err != nil { + r.Equal(puddle.ErrNotAvailable, err) + } else { + r.NotNil(res) + res.Release() + } + } + }() + + time.AfterFunc(stressTestDur(t), func() { close(done) }) + wg.Wait() +} + func startAcceptOnceDummyServer(laddr string) { ln, err := net.Listen("tcp", laddr) if err != nil { @@ -1170,3 +1280,196 @@ func BenchmarkPoolAcquireAndRelease(b *testing.B) { }) } } + +func TestAcquireAllSem(t *testing.T) { + r := require.New(t) + + sem := semaphore.NewWeighted(5) + r.Equal(4, puddle.AcquireSemAll(sem, 4)) + sem.Release(4) + + r.Equal(5, puddle.AcquireSemAll(sem, 5)) + sem.Release(5) + + r.Equal(5, puddle.AcquireSemAll(sem, 6)) + sem.Release(5) +} + +func testPool[T any](t testing.TB) *puddle.Pool[T] { + cfg := puddle.Config[T]{ + MaxSize: 1, + Constructor: func(ctx context.Context) (T, error) { + var zero T + return zero, nil + }, + Destructor: func(T) {}, + } + + pool, err := puddle.NewPool(&cfg) + require.NoError(t, err) + t.Cleanup(pool.Close) + + return pool +} + +func releaser[T any](t testing.TB) chan<- *puddle.Resource[T] { + startChan := make(chan struct{}) + workChan := make(chan *puddle.Resource[T], 1) + + go func() { + close(startChan) + + for r := range workChan { + r.Release() + } + }() + t.Cleanup(func() { close(workChan) }) + + // Wait for goroutine start. + <-startChan + return workChan +} + +func TestReleaseAfterAcquire(t *testing.T) { + const cnt = 100000 + + r := require.New(t) + ctx := context.Background() + pool := testPool[int32](t) + releaseChan := releaser[int32](t) + + res, err := pool.Acquire(ctx) + r.NoError(err) + // We need to release the last connection. Otherwise the pool.Close() + // method will block and this function will never return. + defer func() { res.Release() }() + + for i := 0; i < cnt; i++ { + releaseChan <- res + res, err = pool.Acquire(ctx) + r.NoError(err) + } +} + +func BenchmarkAcquire_ReleaseAfterAcquire(b *testing.B) { + r := require.New(b) + ctx := context.Background() + pool := testPool[int32](b) + releaseChan := releaser[int32](b) + + res, err := pool.Acquire(ctx) + r.NoError(err) + // We need to release the last connection. Otherwise the pool.Close() + // method will block and this function will never return. + defer func() { res.Release() }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + releaseChan <- res + res, err = pool.Acquire(ctx) + r.NoError(err) + } +} + +func withCPULoad() { + // Multiply by 2 to similate overload of the system. + numGoroutines := runtime.NumCPU() * 2 + + var wg sync.WaitGroup + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + wg.Done() + + // Similate computationally intensive task. 
+ for j := 0; true; j++ { + } + }() + } + + wg.Wait() +} + +func BenchmarkAcquire_ReleaseAfterAcquireWithCPULoad(b *testing.B) { + r := require.New(b) + ctx := context.Background() + pool := testPool[int32](b) + releaseChan := releaser[int32](b) + + withCPULoad() + + res, err := pool.Acquire(ctx) + r.NoError(err) + // We need to release the last connection. Otherwise the pool.Close() + // method will block and this function will never return. + defer func() { res.Release() }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + releaseChan <- res + res, err = pool.Acquire(ctx) + r.NoError(err) + } +} + +func BenchmarkAcquire_MultipleCancelled(b *testing.B) { + const cancelCnt = 64 + + r := require.New(b) + ctx := context.Background() + pool := testPool[int32](b) + releaseChan := releaser[int32](b) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + res, err := pool.Acquire(ctx) + r.NoError(err) + // We need to release the last connection. Otherwise the pool.Close() + // method will block and this function will never return. + defer func() { res.Release() }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < cancelCnt; j++ { + _, err = pool.AcquireRaw(cancelCtx) + r.Equal(context.Canceled, err) + } + + releaseChan <- res + res, err = pool.Acquire(ctx) + r.NoError(err) + } +} + +func BenchmarkAcquire_MultipleCancelledWithCPULoad(b *testing.B) { + const cancelCnt = 3 + + r := require.New(b) + ctx := context.Background() + pool := testPool[int32](b) + releaseChan := releaser[int32](b) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + withCPULoad() + + res, err := pool.Acquire(ctx) + r.NoError(err) + // We need to release the last connection. Otherwise the pool.Close() + // method will block and this function will never return. 
+ defer func() { res.Release() }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < cancelCnt; j++ { + _, err = pool.AcquireRaw(cancelCtx) + r.Equal(context.Canceled, err) + } + + releaseChan <- res + res, err = pool.Acquire(ctx) + r.NoError(err) + } +} diff --git a/resource_list.go b/resource_list.go new file mode 100644 index 0000000..b243095 --- /dev/null +++ b/resource_list.go @@ -0,0 +1,28 @@ +package puddle + +type resList[T any] []*Resource[T] + +func (l *resList[T]) append(val *Resource[T]) { *l = append(*l, val) } + +func (l *resList[T]) popBack() *Resource[T] { + idx := len(*l) - 1 + val := (*l)[idx] + (*l)[idx] = nil // Avoid memory leak + *l = (*l)[:idx] + + return val +} + +func (l *resList[T]) remove(val *Resource[T]) { + for i, elem := range *l { + if elem == val { + lastIdx := len(*l) - 1 + (*l)[i] = (*l)[lastIdx] + (*l)[lastIdx] = nil // Avoid memory leak + (*l) = (*l)[:lastIdx] + return + } + } + + panic("BUG: removeResource could not find res in slice") +} diff --git a/resource_list_test.go b/resource_list_test.go new file mode 100644 index 0000000..7104189 --- /dev/null +++ b/resource_list_test.go @@ -0,0 +1,62 @@ +package puddle + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResList_Append(t *testing.T) { + r := require.New(t) + + arr := []*Resource[any]{ + new(Resource[any]), + new(Resource[any]), + new(Resource[any]), + } + + list := resList[any](arr) + + list.append(new(Resource[any])) + r.Len(list, 4) + list.append(new(Resource[any])) + r.Len(list, 5) + list.append(new(Resource[any])) + r.Len(list, 6) +} + +func TestResList_PopBack(t *testing.T) { + r := require.New(t) + + arr := []*Resource[any]{ + new(Resource[any]), + new(Resource[any]), + new(Resource[any]), + } + + list := resList[any](arr) + + list.popBack() + r.Len(list, 2) + list.popBack() + r.Len(list, 1) + list.popBack() + r.Len(list, 0) + + r.Panics(func() { list.popBack() }) +} + +func TestResList_PanicsWithBugReportIfResourceDoesNotExist(t *testing.T) { + arr := []*Resource[any]{ + new(Resource[any]), + new(Resource[any]), + new(Resource[any]), + } + + list := resList[any](arr) + + assert.PanicsWithValue(t, "BUG: removeResource could not find res in slice", func() { + list.remove(new(Resource[any])) + }) +}
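The generational pop order implemented by internal/genstack can be illustrated with a minimal sketch. This is a hypothetical test placed next to gen_stack_test.go (the package is internal, so it has to live inside the module); the test name is illustrative only.

package genstack

import (
    "testing"

    "github.com/stretchr/testify/require"
)

// Hypothetical example: the whole older generation is drained (LIFO within
// itself) before anything pushed after NextGen is popped.
func TestGenStack_PopOrderExample(t *testing.T) {
    s := NewGenStack[int]()

    s.Push(1)
    s.Push(2)
    s.NextGen()
    s.Push(3)
    s.Push(4)

    var got []int
    for v, ok := s.Pop(); ok; v, ok = s.Pop() {
        got = append(got, v)
    }

    require.Equal(t, []int{2, 1, 4, 3}, got)
}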
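The power-of-two fallback in acquireSemAll can be exercised through the AcquireSemAll hook from export_test.go. The sketch below is a hypothetical test alongside pool_test.go; it reproduces the 19-of-24 scenario from the doc comment, where TryAcquire(24) fails and the loop then grabs 16 + 2 + 1 tokens.

package puddle_test

import (
    "testing"

    "github.com/jackc/puddle/v2"
    "github.com/stretchr/testify/require"
    "golang.org/x/sync/semaphore"
)

// Hypothetical example of the acquireSemAll power-of-two fallback.
func TestAcquireSemAll_PartiallyHeld(t *testing.T) {
    sem := semaphore.NewWeighted(24)
    require.True(t, sem.TryAcquire(5)) // leave 19 of 24 tokens free

    // TryAcquire(24) fails, so the loop acquires 16, then 2, then 1.
    require.Equal(t, 19, puddle.AcquireSemAll(sem, 24))

    sem.Release(24) // 5 held above + 19 acquired by AcquireSemAll
}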
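The Acquire contract — cancelling ctx aborts the wait but not the construction — can be sketched as follows. This is a hypothetical test assuming the existing Stat().IdleResources() accessor; the timing values are arbitrary.

package puddle_test

import (
    "context"
    "testing"
    "time"

    "github.com/jackc/puddle/v2"
    "github.com/stretchr/testify/require"
)

// Hypothetical example: cancelling ctx makes Acquire return immediately, but
// the constructor keeps running and the finished resource becomes idle.
func TestAcquire_CancelDoesNotCancelConstruction(t *testing.T) {
    release := make(chan struct{})
    pool, err := puddle.NewPool(&puddle.Config[int]{
        MaxSize: 1,
        Constructor: func(ctx context.Context) (int, error) {
            <-release // simulate a slow constructor
            return 42, nil
        },
        Destructor: func(int) {},
    })
    require.NoError(t, err)
    defer pool.Close()

    ctx, cancel := context.WithCancel(context.Background())
    go func() {
        time.Sleep(10 * time.Millisecond)
        cancel()
    }()

    _, err = pool.Acquire(ctx)
    require.ErrorIs(t, err, context.Canceled) // Acquire gave up waiting

    close(release) // let the constructor finish in the background

    // The construction goroutine releases the finished resource to the
    // idle pool on its own.
    for pool.Stat().IdleResources() == 0 {
        time.Sleep(time.Millisecond)
    }
}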
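The health-check use case mentioned in the AcquireAllIdle doc comment could look like the sketch below; keepAlive and ping are hypothetical names, not part of this change.

package puddle_test

import "github.com/jackc/puddle/v2"

// keepAlive is a hypothetical keep-alive sweep built on AcquireAllIdle.
// ping stands in for an application-level health check.
func keepAlive[T any](pool *puddle.Pool[T], ping func(T) error) {
    for _, res := range pool.AcquireAllIdle() {
        if err := ping(res.Value()); err != nil {
            res.Destroy() // drop broken resources from the pool
        } else {
            res.Release() // healthy resources return to the idle list
        }
    }
}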