diff --git a/batcher/batcher.go b/batcher/batcher.go index 2d1ccb4..6aed64a 100644 --- a/batcher/batcher.go +++ b/batcher/batcher.go @@ -19,6 +19,7 @@ type Batcher struct { lock sync.Mutex submit chan *work doWork func([]interface{}) error + done chan bool } // New constructs a new batcher that will batch all calls to Run that occur within @@ -70,6 +71,7 @@ func (b *Batcher) submitWork(w *work) { defer b.lock.Unlock() if b.submit == nil { + b.done = make(chan bool) b.submit = make(chan *work, 4) go b.batch() } @@ -95,14 +97,35 @@ func (b *Batcher) batch() { future <- ret close(future) } + close(b.done) } func (b *Batcher) timer() { time.Sleep(b.timeout) + b.flush() +} + +// Shutdown flush the changes and wait to be saved +func (b *Batcher) Shutdown(wait bool) { + b.flush() + + if wait { + // wait done channel + <-b.done + } +} + +// Flush saves the changes before the timer expires. +// It is useful to flush the changes when you shutdown your application +func (b *Batcher) flush() { b.lock.Lock() defer b.lock.Unlock() + if b.submit == nil { + return + } + close(b.submit) b.submit = nil } diff --git a/batcher/batcher_test.go b/batcher/batcher_test.go index f1b8d40..bff13a7 100644 --- a/batcher/batcher_test.go +++ b/batcher/batcher_test.go @@ -41,6 +41,53 @@ func TestBatcherSuccess(t *testing.T) { } } +func TestShutdownSuccess(t *testing.T) { + sleepDuration := 5 * time.Millisecond + durationLimit := 2 * sleepDuration + timeout := 2 * durationLimit + total := 0 + doSum := func(params []interface{}) error { + for _, param := range params { + intValue, ok := param.(int) + if !ok { + t.Error("expected type int") + } + total += intValue + } + return nil + } + + b := New(timeout, doSum) + go func() { + time.Sleep(sleepDuration) + b.Shutdown(true) + }() + + wg := &sync.WaitGroup{} + expectedTotal := 0 + start := time.Now() + for i := 0; i < 10; i++ { + expectedTotal += i + wg.Add(1) + go func(i int) { + if err := b.Run(i); err != nil { + t.Error(err) + } + wg.Done() + }(i) + } + wg.Wait() + + duration := time.Since(start) + if duration >= durationLimit { + t.Errorf("expected duration[%v] < durationLimit[%v]", duration, durationLimit) + } + + if total != expectedTotal { + t.Errorf("expected processed count[%v] < actual[%v]", expectedTotal, total) + } +} + func TestBatcherError(t *testing.T) { b := New(10*time.Millisecond, returnsError)