diff --git a/README.md b/README.md index 74ba9fd2..39bcba5f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ + # lo - Iterate over slices, maps, channels... [![tag](https://img.shields.io/github/tag/samber/lo.svg)](https://github.com/samber/lo/releases) @@ -297,6 +298,7 @@ Concurrency helpers: - [AttemptWhileWithDelay](#attemptwhilewithdelay) - [Debounce](#debounce) - [DebounceBy](#debounceby) +- [Throttle](#throttle) - [Synchronize](#synchronize) - [Async](#async) - [Transaction](#transaction) @@ -3417,6 +3419,64 @@ cancel("second key") [[play](https://go.dev/play/p/d3Vpt6pxhY8)] +### Throttle + +Creates a throttled instance that invokes given functions only once in every interval. + +This returns 2 functions, First one is throttled function and Second one is a function to reset interval. + +```go +f := func() { + println("Called once in every 100ms") +} + +throttle, reset := lo.NewThrottle(100 * time.Millisecond, f) + +for j := 0; j < 10; j++ { + throttle() + time.Sleep(30 * time.Millisecond) +} + +reset() +throttle() +``` + +`NewThrottleWithCount` is NewThrottle with count limit, throttled function will be invoked count times in every interval. + +```go +f := func() { + println("Called three times in every 100ms") +} + +throttle, reset := lo.NewThrottleWithCount(100 * time.Millisecond, f) + +for j := 0; j < 10; j++ { + throttle() + time.Sleep(30 * time.Millisecond) +} + +reset() +throttle() +``` + +`NewThrottleBy` and `NewThrottleByWithCount` are NewThrottle with sharding key, throttled function will be invoked count times in every interval. + +```go +f := func(key string) { + println(key, "Called three times in every 100ms") +} + +throttle, reset := lo.NewThrottleByWithCount(100 * time.Millisecond, f) + +for j := 0; j < 10; j++ { + throttle("foo") + time.Sleep(30 * time.Millisecond) +} + +reset() +throttle() +``` + ### Synchronize Wraps the underlying callback in a mutex. It receives an optional mutex. diff --git a/retry.go b/retry.go index 82e8f82f..5b9cef3d 100644 --- a/retry.go +++ b/retry.go @@ -287,4 +287,89 @@ func (t *Transaction[T]) Process(state T) (T, error) { return state, err } -// throttle ? +// @TODO: single mutex per key ? +type throttleBy[T comparable] struct { + mu *sync.Mutex + timer *time.Timer + interval time.Duration + callbacks []func(key T) + countLimit int + count map[T]int +} + +func (th *throttleBy[T]) throttledFunc(key T) { + th.mu.Lock() + defer th.mu.Unlock() + + if _, ok := th.count[key]; !ok { + th.count[key] = 0 + } + + if th.count[key] < th.countLimit { + th.count[key]++ + + for _, f := range th.callbacks { + f(key) + } + + } + if th.timer == nil { + th.timer = time.AfterFunc(th.interval, func() { + th.reset() + }) + } +} + +func (th *throttleBy[T]) reset() { + th.mu.Lock() + defer th.mu.Unlock() + + if th.timer != nil { + th.timer.Stop() + } + + th.count = map[T]int{} + th.timer = nil +} + +// NewThrottle creates a throttled instance that invokes given functions only once in every interval. +// This returns 2 functions, First one is throttled function and Second one is a function to reset interval +func NewThrottle(interval time.Duration, f ...func()) (throttle func(), reset func()) { + return NewThrottleWithCount(interval, 1, f...) +} + +// NewThrottleWithCount is NewThrottle with count limit, throttled function will be invoked count times in every interval. +func NewThrottleWithCount(interval time.Duration, count int, f ...func()) (throttle func(), reset func()) { + callbacks := Map(f, func(item func(), _ int) func(struct{}) { + return func(struct{}) { + item() + } + }) + + throttleFn, reset := NewThrottleByWithCount[struct{}](interval, count, callbacks...) + return func() { + throttleFn(struct{}{}) + }, reset +} + +// NewThrottleBy creates a throttled instance that invokes given functions only once in every interval. +// This returns 2 functions, First one is throttled function and Second one is a function to reset interval +func NewThrottleBy[T comparable](interval time.Duration, f ...func(key T)) (throttle func(key T), reset func()) { + return NewThrottleByWithCount[T](interval, 1, f...) +} + +// NewThrottleByWithCount is NewThrottleBy with count limit, throttled function will be invoked count times in every interval. +func NewThrottleByWithCount[T comparable](interval time.Duration, count int, f ...func(key T)) (throttle func(key T), reset func()) { + if count <= 0 { + count = 1 + } + + th := &throttleBy[T]{ + mu: new(sync.Mutex), + interval: interval, + callbacks: f, + countLimit: count, + count: map[T]int{}, + } + return th.throttledFunc, th.reset +} diff --git a/retry_example_test.go b/retry_example_test.go index 3560c2f0..d08b0f0d 100644 --- a/retry_example_test.go +++ b/retry_example_test.go @@ -249,3 +249,92 @@ func ExampleTransaction_error() { // -5 // error } + +func ExampleNewThrottle() { + throttle, reset := NewThrottle(100*time.Millisecond, func() { + fmt.Println("Called once in every 100ms") + }) + + for j := 0; j < 10; j++ { + throttle() + time.Sleep(30 * time.Millisecond) + } + + reset() + + // Output: + // Called once in every 100ms + // Called once in every 100ms + // Called once in every 100ms +} + +func ExampleNewThrottleWithCount() { + throttle, reset := NewThrottleWithCount(100*time.Millisecond, 2, func() { + fmt.Println("Called once in every 100ms") + }) + + for j := 0; j < 10; j++ { + throttle() + time.Sleep(30 * time.Millisecond) + } + + reset() + + // Output: + // Called once in every 100ms + // Called once in every 100ms + // Called once in every 100ms + // Called once in every 100ms + // Called once in every 100ms + // Called once in every 100ms +} + +func ExampleNewThrottleBy() { + throttle, reset := NewThrottleBy(100*time.Millisecond, func(key string) { + fmt.Println(key, "Called once in every 100ms") + }) + + for j := 0; j < 10; j++ { + throttle("foo") + throttle("bar") + time.Sleep(30 * time.Millisecond) + } + + reset() + + // Output: + // foo Called once in every 100ms + // bar Called once in every 100ms + // foo Called once in every 100ms + // bar Called once in every 100ms + // foo Called once in every 100ms + // bar Called once in every 100ms +} + +func ExampleNewThrottleByWithCount() { + throttle, reset := NewThrottleByWithCount(100*time.Millisecond, 2, func(key string) { + fmt.Println(key, "Called once in every 100ms") + }) + + for j := 0; j < 10; j++ { + throttle("foo") + throttle("bar") + time.Sleep(30 * time.Millisecond) + } + + reset() + + // Output: + // foo Called once in every 100ms + // bar Called once in every 100ms + // foo Called once in every 100ms + // bar Called once in every 100ms + // foo Called once in every 100ms + // bar Called once in every 100ms + // foo Called once in every 100ms + // bar Called once in every 100ms + // foo Called once in every 100ms + // bar Called once in every 100ms + // foo Called once in every 100ms + // bar Called once in every 100ms +} diff --git a/retry_test.go b/retry_test.go index f4094a76..121afa57 100644 --- a/retry_test.go +++ b/retry_test.go @@ -498,3 +498,146 @@ func TestTransaction(t *testing.T) { is.Equal(assert.AnError, err) } } + +func TestNewThrottle(t *testing.T) { + t.Parallel() + is := assert.New(t) + callCount := 0 + f1 := func() { + callCount++ + } + th, reset := NewThrottle(10*time.Millisecond, f1) + + is.Equal(0, callCount) + for j := 0; j < 100; j++ { + th() + } + is.Equal(1, callCount) + + time.Sleep(15 * time.Millisecond) + + for j := 0; j < 100; j++ { + th() + } + + is.Equal(2, callCount) + + // reset counter + reset() + th() + is.Equal(3, callCount) + +} + +func TestNewThrottleWithCount(t *testing.T) { + t.Parallel() + is := assert.New(t) + callCount := 0 + f1 := func() { + callCount++ + } + th, reset := NewThrottleWithCount(10*time.Millisecond, 3, f1) + + // the function does not throttle for initial count number + for i := 0; i < 20; i++ { + th() + } + is.Equal(3, callCount) + + time.Sleep(11 * time.Millisecond) + + for i := 0; i < 20; i++ { + th() + } + + is.Equal(6, callCount) + + reset() + for i := 0; i < 20; i++ { + th() + } + + is.Equal(9, callCount) +} + +func TestNewThrottleBy(t *testing.T) { + t.Parallel() + is := assert.New(t) + callCountA := 0 + callCountB := 0 + f1 := func(key string) { + if key == "a" { + callCountA++ + } else { + callCountB++ + } + } + th, reset := NewThrottleBy[string](10*time.Millisecond, f1) + + is.Equal(0, callCountA) + is.Equal(0, callCountB) + for j := 0; j < 100; j++ { + th("a") + th("b") + } + is.Equal(1, callCountA) + is.Equal(1, callCountB) + + time.Sleep(15 * time.Millisecond) + + for j := 0; j < 100; j++ { + th("a") + th("b") + } + + is.Equal(2, callCountA) + is.Equal(2, callCountB) + + // reset counter + reset() + th("a") + is.Equal(3, callCountA) + is.Equal(2, callCountB) + +} + +func TestNewThrottleByWithCount(t *testing.T) { + t.Parallel() + is := assert.New(t) + callCountA := 0 + callCountB := 0 + f1 := func(key string) { + if key == "a" { + callCountA++ + } else { + callCountB++ + } + } + th, reset := NewThrottleByWithCount(10*time.Millisecond, 3, f1) + + // the function does not throttle for initial count number + for i := 0; i < 20; i++ { + th("a") + th("b") + } + is.Equal(3, callCountA) + is.Equal(3, callCountB) + + time.Sleep(11 * time.Millisecond) + + for i := 0; i < 20; i++ { + th("a") + th("b") + } + + is.Equal(6, callCountA) + is.Equal(6, callCountB) + + reset() + for i := 0; i < 20; i++ { + th("a") + } + + is.Equal(9, callCountA) + is.Equal(6, callCountB) +}