diff --git a/node.go b/node.go index 355545c..803aed9 100644 --- a/node.go +++ b/node.go @@ -75,9 +75,13 @@ type Node struct { // onUpdateHandlers are functions that are called when the node updates. // they are added with `OnUpdate(...)`. onUpdateHandlers []func(context.Context) - // onErrorHandlers are functions that are called when the node updates. + // onErrorHandlers are functions that are called when the node errors in stabilization. // they are added with `OnError(...)`. onErrorHandlers []func(context.Context, error) + // onAbortedHandlers are functions that are called when the node is + // pre-empted for update by another node erroring. + // they are added with `OnError(...)`. + onAbortedHandlers []func(context.Context, error) // stabilizeFn is set during initialization and is a shortcut // to the interface sniff for the node for the IStabilize interface. stabilizeFn func(context.Context) error @@ -142,6 +146,14 @@ func (n *Node) OnError(fn func(context.Context, error)) { n.onErrorHandlers = append(n.onErrorHandlers, fn) } +// OnAborted registers an aborted handler. +// +// An aborted handler is called when the stabilize or cutoff +// function for this node is pre-empted by another node erroring. +func (n *Node) OnAborted(fn func(context.Context, error)) { + n.onAbortedHandlers = append(n.onAbortedHandlers, fn) +} + // Label returns a descriptive label for the node or // an empty string if one hasn't been provided. func (n *Node) Label() string { diff --git a/parallel_stabilize.go b/parallel_stabilize.go index a5ffdbc..d712c32 100644 --- a/parallel_stabilize.go +++ b/parallel_stabilize.go @@ -57,7 +57,12 @@ func (graph *Graph) parallelStabilize(ctx context.Context) (err error) { } if err != nil { // clear if there is an error! - graph.recomputeHeap.clear() + aborted := graph.recomputeHeap.clear() + for _, node := range aborted { + for _, ah := range node.Node().onAbortedHandlers { + ah(ctx, err) + } + } } if len(immediateRecompute) > 0 { graph.recomputeHeap.mu.Lock() diff --git a/parallel_stabilize_test.go b/parallel_stabilize_test.go index e68072a..e1b626c 100644 --- a/parallel_stabilize_test.go +++ b/parallel_stabilize_test.go @@ -121,9 +121,13 @@ func Test_ParallelStabilize_error(t *testing.T) { ctx := testContext() g := New() + var didCallAbortedHandler bool v0 := Var(g, "hello") m0 := Map(g, v0, ident) m1 := Map(g, m0, ident) + m1.Node().OnAborted(func(_ context.Context, err error) { + didCallAbortedHandler = true + }) f0 := Func(g, func(ctx context.Context) (string, error) { return "", fmt.Errorf("this is only a test") @@ -140,6 +144,7 @@ func Test_ParallelStabilize_error(t *testing.T) { testutil.Equal(t, false, g.recomputeHeap.has(m1), "we should clear the recompute heap on error") testutil.Equal(t, false, g.recomputeHeap.has(f0)) + testutil.Equal(t, true, didCallAbortedHandler) } func Test_ParallelStabilize_Always(t *testing.T) { diff --git a/recompute_heap.go b/recompute_heap.go index 37adae9..915bdba 100644 --- a/recompute_heap.go +++ b/recompute_heap.go @@ -19,20 +19,23 @@ type recomputeHeap struct { numItems int } -func (rh *recomputeHeap) clear() { +func (rh *recomputeHeap) clear() (aborted []INode) { rh.mu.Lock() defer rh.mu.Unlock() var next INode for rh.numItems > 0 { + aborted = make([]INode, 0, rh.numItems) next, _ = rh.removeMinUnsafe() next.Node().heightInRecomputeHeap = HeightUnset + aborted = append(aborted, next) } rh.heights = make([]*recomputeHeapList, len(rh.heights)) rh.minHeight = 0 rh.maxHeight = 0 rh.numItems = 0 + return } func (rh *recomputeHeap) len() int { diff --git a/stabilize.go b/stabilize.go index 68605f5..597f630 100644 --- a/stabilize.go +++ b/stabilize.go @@ -44,7 +44,12 @@ func (graph *Graph) Stabilize(ctx context.Context) (err error) { } } if err != nil { - graph.recomputeHeap.clear() + aborted := graph.recomputeHeap.clear() + for _, node := range aborted { + for _, ah := range node.Node().onAbortedHandlers { + ah(ctx, err) + } + } } if len(immediateRecompute) > 0 { for _, n := range immediateRecompute { diff --git a/stabilize_test.go b/stabilize_test.go index ffe0fe7..f21b1c9 100644 --- a/stabilize_test.go +++ b/stabilize_test.go @@ -60,9 +60,13 @@ func Test_Stabilize_error(t *testing.T) { ctx := testContext() g := New() + var didCallAbortedHandler bool v0 := Var(g, "hello") m0 := Map(g, v0, ident) m1 := Map(g, m0, ident) + m1.Node().OnAborted(func(_ context.Context, err error) { + didCallAbortedHandler = true + }) f0 := Func(g, func(_ context.Context) (string, error) { return "", fmt.Errorf("this is just a test") @@ -80,6 +84,7 @@ func Test_Stabilize_error(t *testing.T) { testutil.Equal(t, false, g.recomputeHeap.has(m1), "we should clear the recompute heap on error") testutil.Equal(t, false, g.recomputeHeap.has(f0)) + testutil.Equal(t, true, didCallAbortedHandler) } func Test_Stabilize_errorHandler(t *testing.T) {