Skip to content

Commit

Permalink
batch: refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Viktor Login <[email protected]>
  • Loading branch information
batazor committed Mar 8, 2025
1 parent e0a566f commit 5da6cae
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 66 deletions.
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ linters-settings:

gocritic:
enable-all: true
disabled-checks:
- unnamedResult

godot:
# Comments to be checked: `declarations`, `toplevel`, or `all`.
Expand Down
84 changes: 46 additions & 38 deletions pkg/batch/README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
## Batch Processing

This package offers a robust batch processing system for aggregating and processing items efficiently in batches.
It's designed with concurrency and efficiency in mind, aligning with Go's concurrency patterns.
This package offers a robust batch processing system for aggregating and processing items efficiently in batches.
It's designed with concurrency and efficiency in mind, following Go's concurrency patterns.

### Features

- **Batch Processing:** Groups items for efficient bulk processing.
- **Concurrency Safe:** Thread-safe for reliable operation under concurrent loads.
- **Configurable:** Allows for custom batch sizes and tick intervals.
- **Context Support:** Supports graceful shutdowns and cancellations.
- **Configurable:** Custom batch sizes and flush intervals via options.
- **Context Support:** Graceful shutdowns and cancellations without storing contexts.
- **Generics:** Utilizes Go's generics for type safety.
- **Error Reporting:** Callback errors are reported through an error channel.

### Usage

Expand All @@ -27,40 +28,47 @@ import (
)

func main() {
ctx := context.Background()

// Define the callback function
callback := func(items []*batch.Item[string]) error {
for _, item := range items {
// Process item
time.Sleep(time.Millisecond * 10) // Simulate work
item.CallbackChannel <- item.Item + " processed"
close(item.CallbackChannel)
}
return nil
}

// Create a new batch processor
b, err := batch.New(ctx, callback, batch.WithSize, batch.WithInterval[string](time.Second))
if err != nil {
panic(err)
}

// Push items into the batch processor
for i := 0; i < 20; i++ {
resChan := b.Push(fmt.Sprintf("Item %d", i))
go func(ch chan string) {
result, ok := <-ch
if ok {
fmt.Println(result)
} else {
fmt.Println("Channel closed before processing")
}
}(resChan)
}

// Wait to ensure all items are processed
time.Sleep(2 * time.Second)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Define the callback function to process a batch of items.
callback := func(items []*batch.Item[string]) error {
for _, item := range items {
// Simulate processing work.
time.Sleep(10 * time.Millisecond)
item.CallbackChannel <- item.Item + " processed"
close(item.CallbackChannel)
}

return nil
}

// Create a new batch processor with custom options.
// Note: New returns an error channel to report callback errors.
b, errChan := batch.New(ctx, callback, batch.WithSize[string](5), batch.WithInterval[string](time.Second))

// Process errors from the error channel.
go func() {
for err := range errChan {
fmt.Println("Error:", err)
}
}()

// Push items into the batch processor.
for i := 0; i < 20; i++ {
resChan := b.Push(fmt.Sprintf("Item %d", i))

go func(ch chan string) {
if result, ok := <-ch; ok {
fmt.Println(result)
} else {
fmt.Println("Channel closed before processing")
}
}(resChan)
}

// Wait to ensure all items are processed.
time.Sleep(2 * time.Second)
}
```

Expand Down
68 changes: 56 additions & 12 deletions pkg/batch/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,42 @@ import (
"github.com/spf13/viper"
)

const defaultErrChanBuffer = 10

// New creates a new batch with a specified callback function.
func New[T any](ctx context.Context, callback func([]*Item[T]) error, opts ...Option[T]) (*Batch[T], error) {
func New[T any](ctx context.Context, callback func([]*Item[T]) error, opts ...Option[T]) (*Batch[T], <-chan error) {
viper.SetDefault("BATCH_INTERVAL", "100ms")
viper.SetDefault("BATCH_SIZE", 100)
viper.SetDefault("BATCH_ERROR_BUFFER", defaultErrChanBuffer)

batch := &Batch[T]{
ctx: ctx,
mu: sync.Mutex{},
wg: sync.WaitGroup{},
mu: sync.Mutex{},
wg: sync.WaitGroup{},

callback: callback,
items: []*Item[T]{},
interval: viper.GetDuration("BATCH_INTERVAL"),
size: viper.GetInt("BATCH_SIZE"),
// Instead of storing ctx, use a done channel.
done: make(chan struct{}),
// Buffered error channel to report errors from callback.
errChan: make(chan error, viper.GetInt("BATCH_ERROR_BUFFER")),
}

// Apply options
for _, opt := range opts {
opt(batch)
}

// Launch a goroutine to monitor the passed context.
go func() {
<-ctx.Done()
close(batch.done)
}()

go batch.run()

return batch, nil
return batch, batch.errChan
}

// Push adds an item to the batch.
Expand All @@ -44,8 +56,9 @@ func (batch *Batch[T]) Push(item T) chan T {
Item: item,
}

// Check for cancellation using the done channel.
select {
case <-batch.ctx.Done():
case <-batch.done:
close(newItem.CallbackChannel)
return newItem.CallbackChannel
default:
Expand All @@ -56,7 +69,7 @@ func (batch *Batch[T]) Push(item T) chan T {
shouldFlush := len(batch.items) >= batch.size
batch.mu.Unlock()

// If the batch is full, flush it
// If the batch is full, flush it.
if shouldFlush {
go batch.flushItems()
}
Expand All @@ -71,10 +84,11 @@ func (batch *Batch[T]) run() {

for {
select {
case <-batch.ctx.Done():
case <-batch.done:
batch.flushItems()
batch.wg.Wait()
batch.closePendingChannels()
close(batch.errChan)

return
case <-ticker.C:
Expand All @@ -83,6 +97,7 @@ func (batch *Batch[T]) run() {
}
}

// closePendingChannels closes all pending channels.
func (batch *Batch[T]) closePendingChannels() {
batch.mu.Lock()
defer batch.mu.Unlock()
Expand All @@ -97,19 +112,48 @@ func (batch *Batch[T]) flushItems() {
batch.mu.Lock()
items := batch.items
batch.items = nil

// Check if cancellation has already occurred while still holding the lock.
doneClosed := false

select {
case <-batch.done:
doneClosed = true
default:
}

batch.mu.Unlock()

if len(items) == 0 {
return
}

if doneClosed {
for _, item := range items {
close(item.CallbackChannel)
}

return
}

batch.wg.Add(1)

go func() {
go func(items []*Item[T]) {
defer batch.wg.Done()

if err := batch.callback(items); err != nil {
// Handle error if necessary
// Check cancellation again before proceeding.
select {
case <-batch.done:
for _, item := range items {
close(item.CallbackChannel)
}
default:
if err := batch.callback(items); err != nil {
select {
case batch.errChan <- err:
default:
}
}
}
}()
}(items)
}
36 changes: 22 additions & 14 deletions pkg/batch/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestNew(t *testing.T) {

aggrCB := func(args []*Item[string]) error {
for _, item := range args {
time.Sleep(time.Microsecond * 100) // Emulate long work
time.Sleep(100 * time.Microsecond) // Emulate long work

item.CallbackChannel <- item.Item
close(item.CallbackChannel)
Expand All @@ -36,35 +36,43 @@ func TestNew(t *testing.T) {
return nil
}

b, err := New(ctx, aggrCB)
require.NoError(t, err)
b, errChan := New(ctx, aggrCB)
require.NotNil(t, b)
require.NotNil(t, errChan)

requests := []string{"A", "B", "C", "D"}
for _, request := range requests {
res := b.Push(request)

req := request // Capture range variable
eg.Go(func() error {
val, ok := <-res
require.True(t, ok)
require.Equal(t, req, val)
require.Equal(t, request, val)

return nil
})
}

err = eg.Wait()
require.NoError(t, err)
require.NoError(t, eg.Wait())

// Cancel the context to trigger cleanup.
cancelFunc()

// Drain the error channel; it should be closed without delivering any errors.
for range errChan {
// No errors expected.
}
})

t.Run("Check context cancellation", func(t *testing.T) {
ctx, cancelFunc := context.WithTimeout(context.Background(), time.Millisecond*10)
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancelFunc()

eg, ctx := errgroup.WithContext(ctx)

aggrCB := func(args []*Item[string]) error {
for _, item := range args {
time.Sleep(time.Second * 10) // Emulate long work
time.Sleep(10 * time.Second) // Emulate long work

item.CallbackChannel <- item.Item
close(item.CallbackChannel)
Expand All @@ -74,21 +82,21 @@ func TestNew(t *testing.T) {
}

requests := []string{"A", "B", "C", "D"}

b, err := New(ctx, aggrCB)
require.NoError(t, err)
b, errChan := New(ctx, aggrCB)
require.NotNil(t, b)
require.NotNil(t, errChan)

for _, request := range requests {
res := b.Push(request)

eg.Go(func() error {
_, ok := <-res
require.False(t, ok)

return nil
})
}

err = eg.Wait()
require.NoError(t, err)
require.NoError(t, eg.Wait())
})
}
4 changes: 2 additions & 2 deletions pkg/batch/type.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package batch

import (
"context"
"sync"
"time"
)

// Batch is a structure for batch processing.
type Batch[T any] struct {
ctx context.Context
done chan struct{}
errChan chan error
callback func([]*Item[T]) error
items []*Item[T]
wg sync.WaitGroup
Expand Down

0 comments on commit 5da6cae

Please sign in to comment.