diff --git a/README.md b/README.md index dd848c39..4fffdeb8 100644 --- a/README.md +++ b/README.md @@ -242,6 +242,7 @@ Concurrency helpers: - [AttemptWhileWithDelay](#attemptwhilewithdelay) - [Debounce](#debounce) - [DebounceBy](#debounceby) +- [Throttle](#throttle) - [Synchronize](#synchronize) - [Async](#async) - [Transaction](#transaction) @@ -2529,6 +2530,47 @@ cancel("second key") [[play](https://go.dev/play/p/d3Vpt6pxhY8)] +### Throttle +`NewThrottle` creates a throttled instance that invokes given functions only once in every interval. +This returns 2 functions, First one is throttled function and Second one is a function to reset interval. + +```go + +f := func() { + println("Called once in every 100ms") +} + +throttle, reset := lo.NewThrottle(100 * time.Millisecond, f) + +for j := 0; j < 10; j++ { + throttle() + time.Sleep(30 * time.Millisecond) +} + +reset() +throttle() + +``` + +`NewThrottleWithCount` is NewThrottle with count limit, throttled function will be invoked count times in every interval. +```go + +f := func() { + println("Called three times in every 100ms") +} + +throttle, reset := lo.NewThrottle(100 * time.Millisecond, f) + +for j := 0; j < 10; j++ { + throttle() + time.Sleep(30 * time.Millisecond) +} + +reset() +throttle() + +``` + ### Synchronize Wraps the underlying callback in a mutex. It receives an optional mutex. diff --git a/retry.go b/retry.go index 11f456d2..eda897ea 100644 --- a/retry.go +++ b/retry.go @@ -287,4 +287,62 @@ func (t *Transaction[T]) Process(state T) (T, error) { return state, err } -// throttle ? +type throttle struct { + mu *sync.Mutex + timer *time.Timer + interval time.Duration + callbacks []func() + countLimit int + count int +} + +func (th *throttle) throttledFunc() { + th.mu.Lock() + defer th.mu.Unlock() + if th.count < th.countLimit { + th.count++ + + for _, f := range th.callbacks { + f() + } + + } + if th.timer == nil { + th.timer = time.AfterFunc(th.interval, func() { + th.reset() + }) + } +} + +func (th *throttle) reset() { + th.mu.Lock() + defer th.mu.Unlock() + + if th.timer != nil { + th.timer.Stop() + } + + th.count = 0 + th.timer = nil + +} + +// NewThrottle creates a throttled instance that invokes given functions only once in every interval. +// This returns 2 functions, First one is throttled function and Second one is a function to reset interval +func NewThrottle(interval time.Duration, f ...func()) (func(), func()) { + return NewThrottleWithCount(interval, 1, f...) +} + +// NewThrottleWithCount is NewThrottle with count limit, throttled function will be invoked count times in every interval. +func NewThrottleWithCount(interval time.Duration, count int, f ...func()) (func(), func()) { + if count <= 0 { + count = 1 + } + th := &throttle{ + mu: new(sync.Mutex), + interval: interval, + callbacks: f, + countLimit: count, + } + return th.throttledFunc, th.reset +} diff --git a/retry_test.go b/retry_test.go index 1ac00703..5e44fa95 100644 --- a/retry_test.go +++ b/retry_test.go @@ -498,3 +498,66 @@ func TestTransaction(t *testing.T) { is.Equal(assert.AnError, err) } } + +func TestNewThrottle(t *testing.T) { + t.Parallel() + is := assert.New(t) + callCount := 0 + f1 := func() { + callCount++ + } + th, reset := NewThrottle(10*time.Millisecond, f1) + + is.Equal(0, callCount) + for i := 0; i < 9; i++ { + var wg sync.WaitGroup + for j := 0; j < 100; j++ { + wg.Add(1) + go func() { + defer wg.Done() + th() + }() + } + wg.Wait() + time.Sleep(3 * time.Millisecond) + } + // 35 ms passed + is.Equal(3, callCount) + + // reset counter + reset() + th() + is.Equal(4, callCount) + +} + +func TestNewThrottleWithCount(t *testing.T) { + t.Parallel() + is := assert.New(t) + callCount := 0 + f1 := func() { + callCount++ + } + th, reset := NewThrottleWithCount(10*time.Millisecond, 3, f1) + + // the function does not throttle for initial count number + for i := 0; i < 20; i++ { + th() + } + is.Equal(3, callCount) + + time.Sleep(11 * time.Millisecond) + + for i := 0; i < 20; i++ { + th() + } + + is.Equal(6, callCount) + + reset() + for i := 0; i < 20; i++ { + th() + } + + is.Equal(9, callCount) +}