Skip to content

Commit

Permalink
buffer & grpcsync: various cleanups and improvements (#6785)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfawley authored Nov 15, 2023
1 parent 424db25 commit b98104e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 58 deletions.
41 changes: 26 additions & 15 deletions internal/buffer/unbounded.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
// Package buffer provides an implementation of an unbounded buffer.
package buffer

import "sync"
import (
"errors"
"sync"
)

// Unbounded is an implementation of an unbounded buffer which does not use
// extra goroutines. This is typically used for passing updates from one entity
Expand All @@ -36,6 +39,7 @@ import "sync"
type Unbounded struct {
c chan any
closed bool
closing bool
mu sync.Mutex
backlog []any
}
Expand All @@ -45,39 +49,41 @@ func NewUnbounded() *Unbounded {
return &Unbounded{c: make(chan any, 1)}
}

var errBufferClosed = errors.New("Put called on closed buffer.Unbounded")

// Put adds t to the unbounded buffer.
func (b *Unbounded) Put(t any) {
func (b *Unbounded) Put(t any) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return
if b.closing {
return errBufferClosed
}
if len(b.backlog) == 0 {
select {
case b.c <- t:
return
return nil
default:
}
}
b.backlog = append(b.backlog, t)
return nil
}

// Load sends the earliest buffered data, if any, onto the read channel
// returned by Get(). Users are expected to call this every time they read a
// Load sends the earliest buffered data, if any, onto the read channel returned
// by Get(). Users are expected to call this every time they successfully read a
// value from the read channel.
func (b *Unbounded) Load() {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return
}
if len(b.backlog) > 0 {
select {
case b.c <- b.backlog[0]:
b.backlog[0] = nil
b.backlog = b.backlog[1:]
default:
}
} else if b.closing && !b.closed {
close(b.c)
}
}

Expand All @@ -88,18 +94,23 @@ func (b *Unbounded) Load() {
// send the next buffered value onto the channel if there is any.
//
// If the unbounded buffer is closed, the read channel returned by this method
// is closed.
// is closed after all data is drained.
func (b *Unbounded) Get() <-chan any {
return b.c
}

// Close closes the unbounded buffer.
// Close closes the unbounded buffer. No subsequent data may be Put(), and the
// channel returned from Get() will be closed after all the data is read and
// Load() is called for the final time.
func (b *Unbounded) Close() {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
if b.closing {
return
}
b.closed = true
close(b.c)
b.closing = true
if len(b.backlog) == 0 {
b.closed = true
close(b.c)
}
}
21 changes: 16 additions & 5 deletions internal/buffer/unbounded_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func init() {
}

// TestSingleWriter starts one reader and one writer goroutine and makes sure
// that the reader gets all the value added to the buffer by the writer.
// that the reader gets all the values added to the buffer by the writer.
func (s) TestSingleWriter(t *testing.T) {
ub := NewUnbounded()
reads := []int{}
Expand Down Expand Up @@ -124,14 +124,25 @@ func (s) TestMultipleWriters(t *testing.T) {
// buffer is closed.
func (s) TestClose(t *testing.T) {
ub := NewUnbounded()
if err := ub.Put(1); err != nil {
t.Fatalf("Unbounded.Put() = %v; want nil", err)
}
ub.Close()
if v, ok := <-ub.Get(); ok {
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
if err := ub.Put(1); err == nil {
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
}
if v, ok := <-ub.Get(); !ok {
t.Errorf("Unbounded.Get() = %v, %v, want %v, %v", v, ok, 1, true)
}
if err := ub.Put(1); err == nil {
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
}
ub.Put(1)
ub.Load()
if v, ok := <-ub.Get(); ok {
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
}
ub.Close()
if err := ub.Put(1); err == nil {
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
}
ub.Close() // ignored
}
51 changes: 13 additions & 38 deletions internal/grpcsync/callback_serializer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package grpcsync

import (
"context"
"sync"

"google.golang.org/grpc/internal/buffer"
)
Expand All @@ -38,8 +37,6 @@ type CallbackSerializer struct {
done chan struct{}

callbacks *buffer.Unbounded
closedMu sync.Mutex
closed bool
}

// NewCallbackSerializer returns a new CallbackSerializer instance. The provided
Expand All @@ -65,56 +62,34 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
// callbacks to be executed by the serializer. It is not possible to add
// callbacks once the context passed to NewCallbackSerializer is cancelled.
func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool {
cs.closedMu.Lock()
defer cs.closedMu.Unlock()

if cs.closed {
return false
}
cs.callbacks.Put(f)
return true
return cs.callbacks.Put(f) == nil
}

func (cs *CallbackSerializer) run(ctx context.Context) {
var backlog []func(context.Context)

defer close(cs.done)

// TODO: when Go 1.21 is the oldest supported version, this loop and Close
// can be replaced with:
//
// context.AfterFunc(ctx, cs.callbacks.Close)
for ctx.Err() == nil {
select {
case <-ctx.Done():
// Do nothing here. Next iteration of the for loop will not happen,
// since ctx.Err() would be non-nil.
case callback, ok := <-cs.callbacks.Get():
if !ok {
return
}
case cb := <-cs.callbacks.Get():
cs.callbacks.Load()
callback.(func(ctx context.Context))(ctx)
cb.(func(context.Context))(ctx)
}
}

// Fetch pending callbacks if any, and execute them before returning from
// this method and closing cs.done.
cs.closedMu.Lock()
cs.closed = true
backlog = cs.fetchPendingCallbacks()
// Close the buffer to prevent new callbacks from being added.
cs.callbacks.Close()
cs.closedMu.Unlock()
for _, b := range backlog {
b(ctx)
}
}

func (cs *CallbackSerializer) fetchPendingCallbacks() []func(context.Context) {
var backlog []func(context.Context)
for {
select {
case b := <-cs.callbacks.Get():
backlog = append(backlog, b.(func(context.Context)))
cs.callbacks.Load()
default:
return backlog
}
// Run all pending callbacks.
for cb := range cs.callbacks.Get() {
cs.callbacks.Load()
cb.(func(context.Context))(ctx)
}
}

Expand Down

0 comments on commit b98104e

Please sign in to comment.