Skip to content

Commit

Permalink
adds aborted handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
wcharczuk committed Jul 18, 2024
1 parent 2a8818f commit 330dcae
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 4 deletions.
14 changes: 13 additions & 1 deletion node.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@ type Node struct {
// onUpdateHandlers are functions that are called when the node updates.
// they are added with `OnUpdate(...)`.
onUpdateHandlers []func(context.Context)
// onErrorHandlers are functions that are called when the node updates.
// onErrorHandlers are functions that are called when the node errors in stabilization.
// they are added with `OnError(...)`.
onErrorHandlers []func(context.Context, error)
// onAbortedHandlers are functions that are called when the node is
// pre-empted for update by another node erroring.
// they are added with `OnError(...)`.
onAbortedHandlers []func(context.Context, error)
// stabilizeFn is set during initialization and is a shortcut
// to the interface sniff for the node for the IStabilize interface.
stabilizeFn func(context.Context) error
Expand Down Expand Up @@ -142,6 +146,14 @@ func (n *Node) OnError(fn func(context.Context, error)) {
n.onErrorHandlers = append(n.onErrorHandlers, fn)
}

// OnAborted registers an aborted handler.
//
// An aborted handler is called when the stabilize or cutoff
// function for this node is pre-empted by another node erroring.
func (n *Node) OnAborted(fn func(context.Context, error)) {
n.onAbortedHandlers = append(n.onAbortedHandlers, fn)
}

// Label returns a descriptive label for the node or
// an empty string if one hasn't been provided.
func (n *Node) Label() string {
Expand Down
7 changes: 6 additions & 1 deletion parallel_stabilize.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ func (graph *Graph) parallelStabilize(ctx context.Context) (err error) {
}
if err != nil {
// clear if there is an error!
graph.recomputeHeap.clear()
aborted := graph.recomputeHeap.clear()
for _, node := range aborted {
for _, ah := range node.Node().onAbortedHandlers {
ah(ctx, err)
}
}
}
if len(immediateRecompute) > 0 {
graph.recomputeHeap.mu.Lock()
Expand Down
5 changes: 5 additions & 0 deletions parallel_stabilize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,13 @@ func Test_ParallelStabilize_error(t *testing.T) {
ctx := testContext()
g := New()

var didCallAbortedHandler bool
v0 := Var(g, "hello")
m0 := Map(g, v0, ident)
m1 := Map(g, m0, ident)
m1.Node().OnAborted(func(_ context.Context, err error) {
didCallAbortedHandler = true
})

f0 := Func(g, func(ctx context.Context) (string, error) {
return "", fmt.Errorf("this is only a test")
Expand All @@ -140,6 +144,7 @@ func Test_ParallelStabilize_error(t *testing.T) {

testutil.Equal(t, false, g.recomputeHeap.has(m1), "we should clear the recompute heap on error")
testutil.Equal(t, false, g.recomputeHeap.has(f0))
testutil.Equal(t, true, didCallAbortedHandler)
}

func Test_ParallelStabilize_Always(t *testing.T) {
Expand Down
5 changes: 4 additions & 1 deletion recompute_heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,23 @@ type recomputeHeap struct {
numItems int
}

func (rh *recomputeHeap) clear() {
func (rh *recomputeHeap) clear() (aborted []INode) {
rh.mu.Lock()
defer rh.mu.Unlock()

var next INode
for rh.numItems > 0 {
aborted = make([]INode, 0, rh.numItems)
next, _ = rh.removeMinUnsafe()
next.Node().heightInRecomputeHeap = HeightUnset
aborted = append(aborted, next)
}

rh.heights = make([]*recomputeHeapList, len(rh.heights))
rh.minHeight = 0
rh.maxHeight = 0
rh.numItems = 0
return
}

func (rh *recomputeHeap) len() int {
Expand Down
7 changes: 6 additions & 1 deletion stabilize.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ func (graph *Graph) Stabilize(ctx context.Context) (err error) {
}
}
if err != nil {
graph.recomputeHeap.clear()
aborted := graph.recomputeHeap.clear()
for _, node := range aborted {
for _, ah := range node.Node().onAbortedHandlers {
ah(ctx, err)
}
}
}
if len(immediateRecompute) > 0 {
for _, n := range immediateRecompute {
Expand Down
5 changes: 5 additions & 0 deletions stabilize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ func Test_Stabilize_error(t *testing.T) {
ctx := testContext()
g := New()

var didCallAbortedHandler bool
v0 := Var(g, "hello")
m0 := Map(g, v0, ident)
m1 := Map(g, m0, ident)
m1.Node().OnAborted(func(_ context.Context, err error) {
didCallAbortedHandler = true
})

f0 := Func(g, func(_ context.Context) (string, error) {
return "", fmt.Errorf("this is just a test")
Expand All @@ -80,6 +84,7 @@ func Test_Stabilize_error(t *testing.T) {

testutil.Equal(t, false, g.recomputeHeap.has(m1), "we should clear the recompute heap on error")
testutil.Equal(t, false, g.recomputeHeap.has(f0))
testutil.Equal(t, true, didCallAbortedHandler)
}

func Test_Stabilize_errorHandler(t *testing.T) {
Expand Down

0 comments on commit 330dcae

Please sign in to comment.