diff --git a/bind.go b/bind.go index 488a70c..b7c7be6 100644 --- a/bind.go +++ b/bind.go @@ -62,12 +62,19 @@ func BindContext[A, B any](scope Scope, input Incr[A], fn BindContextFunc[A, B]) parents: []INode{bindLeftChange}, }) bind.main = bindMain + // propagate errors to main from the left change node bindLeftChange.n.onErrorHandlers = append(bindLeftChange.n.onErrorHandlers, func(ctx context.Context, err error) { for _, eh := range bindMain.n.onErrorHandlers { eh(ctx, err) } }) + // propagate aborted events to main from the left change node + bindLeftChange.n.onAbortedHandlers = append(bindLeftChange.n.onAbortedHandlers, func(ctx context.Context, err error) { + for _, eh := range bindMain.n.onAbortedHandlers { + eh(ctx, err) + } + }) return bindMain } diff --git a/bind_test.go b/bind_test.go index edfce88..de5bd24 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1179,7 +1179,6 @@ func Test_Bind_cycle(t *testing.T) { nodesThatErrored := make(map[Identifier]INode) var b1 BindIncr[string] - b0v := Var(g, "a") b0v.Node().OnError(func(_ context.Context, err error) { nodesThatErrored[b0v.Node().id] = b0v @@ -1199,10 +1198,7 @@ func Test_Bind_cycle(t *testing.T) { nodesThatErrored[b1v.Node().id] = b1v }) b1 = Bind(g, b1v, func(bs Scope, which string) Incr[string] { - if which == "a" { - return b0 - } - return Return(bs, "bar") + return b0 }) b1.Node().OnError(func(_ context.Context, err error) { nodesThatErrored[b1.Node().id] = b1 @@ -1288,3 +1284,69 @@ func Test_bindLeftChange_RightScopeNodes(t *testing.T) { testutil.Equal(t, 2, len(bindTyped.bind.lhsChange.RightScopeNodes())) } + +func Test_Bind_aborted(t *testing.T) { + ctx := testContext() + g := New( + OptGraphClearRecomputeHeapOnError(true), + ) + nodesThatErrored := make(map[Identifier]INode) + nodesThatAborted := make(map[Identifier]INode) + + hookNode := func(n INode) { + n.Node().OnError(func(_ context.Context, _ error) { + nodesThatErrored[n.Node().id] = n + }) + n.Node().OnAborted(func(_ context.Context, _ error) { + nodesThatAborted[n.Node().id] = n + }) + } + + var b1, b2 BindIncr[string] + b0v := Var(g, "a") + hookNode(b0v) + b0 := Bind(g, b0v, func(bs Scope, which string) Incr[string] { + if which == "a" { + return Return(bs, "foo") + } + return b1 + }) + hookNode(b0) + + b1v := Var(g, "a") + hookNode(b1v) + b1 = Bind(g, b1v, func(bs Scope, which string) Incr[string] { + return b2 + }) + hookNode(b1) + + b2v := Var(g, "a") + hookNode(b2v) + b2 = Bind(g, b2v, func(bs Scope, which string) Incr[string] { + return b0 + }) + hookNode(b2) + + o := MustObserve(g, b2) + hookNode(o) + + err := g.Stabilize(ctx) + testutil.NoError(t, err) + testutil.Equal(t, "foo", o.Value()) + + b0v.Set("b") + + err = g.Stabilize(ctx) + testutil.Error(t, err) + testutil.Equal(t, "foo", o.Value()) + + testutil.Equal(t, 1, len(nodesThatErrored)) + _, ok := nodesThatErrored[b1.Node().id] + testutil.Equal(t, true, ok) + + testutil.Equal(t, 2, len(nodesThatAborted)) + _, ok = nodesThatAborted[b0.Node().id] + testutil.Equal(t, true, ok) + _, ok = nodesThatAborted[b1.Node().id] + testutil.Equal(t, true, ok) +} diff --git a/graph.go b/graph.go index 81b0c24..fea3043 100644 --- a/graph.go +++ b/graph.go @@ -106,7 +106,10 @@ func OptGraphPreallocateSentinelsSize(size int) func(*GraphOptions) { // OptGraphClearRecomputeHeapOnError controls a setting for whether or not the // recompute heap is cleared of nodes on stabilization error. // -// If not provided, the default is to not clear the recompute heap, but leave nodes in place. +// By default the graph will not clear the recompute heap, and instead leave nodes in place. +// +// If this option is provided, and `shouldClear` is `true`, then the recompute heap +// will be cleared on error, and the `OnAborted` handlers of nodes will be called. func OptGraphClearRecomputeHeapOnError(shouldClear bool) func(*GraphOptions) { return func(g *GraphOptions) { g.ClearRecomputeHeapOnError = shouldClear diff --git a/recompute_heap.go b/recompute_heap.go index 915bdba..8190157 100644 --- a/recompute_heap.go +++ b/recompute_heap.go @@ -24,8 +24,8 @@ func (rh *recomputeHeap) clear() (aborted []INode) { defer rh.mu.Unlock() var next INode + aborted = make([]INode, 0, rh.numItems) for rh.numItems > 0 { - aborted = make([]INode, 0, rh.numItems) next, _ = rh.removeMinUnsafe() next.Node().heightInRecomputeHeap = HeightUnset aborted = append(aborted, next)