From 602b86980662b8ef58c2b28a3425b2a02b7b16d5 Mon Sep 17 00:00:00 2001 From: Will Charczuk Date: Sun, 11 Feb 2024 22:44:15 -0800 Subject: [PATCH] Refactors how observers are treated; they only mark leaf nodes as necessary now. (#13) --- adjust_heights_heap.go | 18 ++- always.go | 6 +- always_test.go | 1 - bench_test.go | 23 +-- bind.go | 85 +++++------ bind_if.go | 2 +- bind_test.go | 276 +++++++++++------------------------ cutoff.go | 8 +- cutoff2.go | 6 +- doc.go | 6 +- dot.go | 8 +- dot_test.go | 35 +++-- examples/integration/main.go | 62 +++++++- expert_graph.go | 24 +-- expert_graph_test.go | 49 ++----- expert_node.go | 7 + expert_node_test.go | 74 ++++++---- fold_left.go | 6 +- fold_map.go | 10 +- fold_right.go | 6 +- freeze.go | 6 +- func.go | 5 +- graph.go | 205 +++++++++++++------------- graph_test.go | 95 +----------- identifier.go | 11 ++ incrutil/diff_maps.go | 22 ++- incrutil/diff_slice.go | 6 +- link.go | 16 ++ main_test.go | 24 +-- map.go | 6 +- map2.go | 6 +- map3.go | 6 +- map4.go | 6 +- map_if.go | 6 +- map_n.go | 6 +- node.go | 19 --- node_list.go | 5 + node_list_test.go | 8 +- node_test.go | 86 ++++++----- observe.go | 25 ++-- observe_test.go | 138 ------------------ observer_overhaul_notes.md | 27 ++++ parallel_stabilize.go | 12 +- parallel_stabilize_test.go | 12 -- recompute_heap_test.go | 97 ++++++------ scope.go | 12 +- set.go | 16 ++ stabilize.go | 3 - stabilize_test.go | 120 +++------------ timer.go | 6 +- unlink.go | 1 + watch.go | 8 +- 52 files changed, 707 insertions(+), 1026 deletions(-) create mode 100644 observer_overhaul_notes.md diff --git a/adjust_heights_heap.go b/adjust_heights_heap.go index 16068ed..bb49687 100644 --- a/adjust_heights_heap.go +++ b/adjust_heights_heap.go @@ -113,13 +113,16 @@ func (ah *adjustHeightsHeap) setHeight(node INode, height int) error { return nil } -func (ah *adjustHeightsHeap) ensureHeightRequirement(child, parent INode) error { +func (ah *adjustHeightsHeap) ensureHeightRequirement(originalChild, originalParent, child, parent INode) error { ah.mu.Lock() defer ah.mu.Unlock() - return ah.ensureHeightRequirementUnsafe(child, parent) + return ah.ensureHeightRequirementUnsafe(originalChild, originalParent, child, parent) } -func (ah *adjustHeightsHeap) ensureHeightRequirementUnsafe(child, parent INode) error { +func (ah *adjustHeightsHeap) ensureHeightRequirementUnsafe(originalChild, originalParent, child, parent INode) error { + if originalParent.Node().id == child.Node().id { + return fmt.Errorf("cycle detected at %v", child) + } if parent.Node().height >= child.Node().height { if err := ah.setHeight(child, parent.Node().height+1); err != nil { return err @@ -130,14 +133,17 @@ func (ah *adjustHeightsHeap) ensureHeightRequirementUnsafe(child, parent INode) return nil } -func (ah *adjustHeightsHeap) adjustHeights(rh *recomputeHeap) error { +func (ah *adjustHeightsHeap) adjustHeights(rh *recomputeHeap, originalChild, originalParent INode) error { ah.mu.Lock() defer ah.mu.Unlock() + if err := ah.ensureHeightRequirementUnsafe(originalChild, originalParent, originalChild, originalParent); err != nil { + return err + } for len(ah.lookup) > 0 { node, _ := ah.removeMinUnsafe() rh.fix(node.Node().id) for _, child := range node.Node().children { - if err := ah.ensureHeightRequirementUnsafe(child, node); err != nil { + if err := ah.ensureHeightRequirementUnsafe(originalChild, originalParent, child, node); err != nil { return err } } @@ -146,7 +152,7 @@ func (ah *adjustHeightsHeap) adjustHeights(rh *recomputeHeap) error { if scopeOK { for _, nodeOnRight := range scope.rhsNodes { if node.Node().graph.isNecessary(nodeOnRight) { - if err := ah.ensureHeightRequirementUnsafe(node, nodeOnRight); err != nil { + if err := ah.ensureHeightRequirementUnsafe(originalChild, originalParent, node, nodeOnRight); err != nil { return err } } diff --git a/always.go b/always.go index ee8e24b..6224810 100644 --- a/always.go +++ b/always.go @@ -3,12 +3,12 @@ package incr // Always returns an incremental that is always stale and will be // marked for recomputation. func Always[A any](scope Scope, input Incr[A]) Incr[A] { - a := &alwaysIncr[A]{ + a := WithinScope(scope, &alwaysIncr[A]{ n: NewNode("always"), input: input, - } + }) Link(a, input) - return WithinScope(scope, a) + return a } // AlwaysIncr is a type that implements the always stale incremental. diff --git a/always_test.go b/always_test.go index f2b0984..a8147da 100644 --- a/always_test.go +++ b/always_test.go @@ -20,7 +20,6 @@ func Test_Always(t *testing.T) { m1.Node().OnUpdate(func(_ context.Context) { updates++ }) - o := Observe(g, m1) ctx := testContext() diff --git a/bench_test.go b/bench_test.go index 721835f..f491c23 100644 --- a/bench_test.go +++ b/bench_test.go @@ -107,7 +107,7 @@ func Benchmark_Stabilize_connectedGraph_with_nestedBinds_128(b *testing.B) { benchmarkConnectedGraphWithNestedBinds(128, b) } -func benchmarkSize(size int, b *testing.B) { +func makeBenchmarkGraph(size int) (*Graph, []Incr[string]) { graph := New() nodes := make([]Incr[string], size) for x := 0; x < size; x++ { @@ -124,7 +124,11 @@ func benchmarkSize(size int, b *testing.B) { } _ = Observe(graph, nodes[len(nodes)-1]) + return graph, nodes +} +func benchmarkSize(size int, b *testing.B) { + graph, nodes := makeBenchmarkGraph(size) // this is what we care about ctx := context.Background() b.ResetTimer() @@ -152,22 +156,7 @@ func benchmarkSize(size int, b *testing.B) { } func benchmarkParallelSize(size int, b *testing.B) { - graph := New() - nodes := make([]Incr[string], size) - for x := 0; x < size; x++ { - nodes[x] = Var(graph, fmt.Sprintf("var_%d", x)) - } - - var cursor int - for x := size; x > 0; x >>= 1 { - for y := 0; y < x-1; y += 2 { - n := Map2(graph, nodes[cursor+y], nodes[cursor+y+1], concat) - nodes = append(nodes, n) - } - cursor += x - } - - _ = Observe(graph, nodes[0]) + graph, nodes := makeBenchmarkGraph(size) // this is what we care about ctx := context.Background() diff --git a/bind.go b/bind.go index a699264..5d70b44 100644 --- a/bind.go +++ b/bind.go @@ -39,16 +39,17 @@ func Bind[A, B any](scope Scope, input Incr[A], fn func(Scope, A) Incr[B]) BindI // If an error returned, the bind is aborted, the error listener(s) will fire for the node, and the // computation will stop. func BindContext[A, B any](scope Scope, input Incr[A], fn func(context.Context, Scope, A) (Incr[B], error)) BindIncr[B] { - o := &bindIncr[A, B]{ + o := WithinScope(scope, &bindIncr[A, B]{ n: NewNode("bind"), input: input, fn: fn, - } + }) o.scope = &bindScope{ - bind: o, + input: input, + bind: o, } Link(o, input) - return WithinScope(scope, o) + return o } // BindIncr is a node that implements Bind, which can dynamically swap out @@ -100,7 +101,7 @@ func (b *bindIncr[A, B]) Scope() Scope { } func (b *bindIncr[A, B]) didInputChange() bool { - return b.input.Node().changedAt >= b.n.recomputedAt + return b.input.Node().changedAt >= b.n.changedAt } func (b *bindIncr[A, B]) Stabilize(ctx context.Context) error { @@ -111,6 +112,7 @@ func (b *bindIncr[A, B]) Stabilize(ctx context.Context) error { // we do want to propagate changes to the bound node to the bind // node's children however, so some trickery is involved. if !b.didInputChange() { + TracePrintf(ctx, "%v input unchanged", b) // NOTE (wc): ok so this is a tangle. // we halt computation based on boundAt for nodes that // set their bound at. So if our bound node triggered @@ -129,20 +131,25 @@ func (b *bindIncr[A, B]) Stabilize(ctx context.Context) error { if b.bound != nil && newIncr != nil { if b.bound.Node().id != newIncr.Node().id { bindChanged = true - b.unlinkOldBound(ctx, b.n.Observers()...) + b.unlinkOldBound(ctx, b.n.observers...) if err := b.linkNewBound(ctx, newIncr); err != nil { return err } + } else { + bindChanged = b.bound.Node().changedAt > b.n.boundAt + TracePrintf(ctx, "%v bound to same node after stabilization", b) } } else if newIncr != nil { bindChanged = true - b.linkBindChange(ctx) + if err := b.linkBindChange(ctx); err != nil { + return err + } if err := b.linkNewBound(ctx, newIncr); err != nil { return err } } else if b.bound != nil { bindChanged = true - b.unlinkOldBound(ctx, b.n.Observers()...) + b.unlinkOldBound(ctx, b.n.observers...) b.unlinkBindChange(ctx) } if bindChanged { @@ -158,7 +165,6 @@ func (b *bindIncr[A, B]) Unobserve(ctx context.Context, observers ...IObserver) func (b *bindIncr[A, B]) Link(ctx context.Context) (err error) { if b.bound != nil { - _ = b.n.graph.adjustHeightsHeap.ensureHeightRequirement(b, b.bound) for _, n := range b.scope.rhsNodes { if typed, ok := n.(IBind); ok { if err = typed.Link(ctx); err != nil { @@ -166,46 +172,31 @@ func (b *bindIncr[A, B]) Link(ctx context.Context) (err error) { } } } + err = b.n.graph.adjustHeightsHeap.adjustHeights(b.n.graph.recomputeHeap, b, b.bound) + if err != nil { + return + } } return } -func (b *bindIncr[A, B]) linkBindChange(ctx context.Context) { - b.bindChange = &bindChangeIncr[A, B]{ +func (b *bindIncr[A, B]) linkBindChange(ctx context.Context) error { + b.bindChange = WithinScope(b.n.createdIn, &bindChangeIncr[A, B]{ n: NewNode("bind-lhs-change"), lhs: b.input, rhs: b.bound, - } + }) if b.n.label != "" { b.bindChange.n.SetLabel(fmt.Sprintf("%s-change", b.n.label)) } - b.bindChange.n.createdIn = b.n.createdIn Link(b.bindChange, b.input) - b.n.graph.observeSingleNode(b.bindChange, b.n.Observers()...) -} - -func (b *bindIncr[A, B]) unlinkBindChange(ctx context.Context) { - if b.bindChange != nil { - if b.bound != nil { - Unlink(b.bound, b.bindChange) - } - Unlink(b.bindChange, b.input) - - // NOTE (wc): we don't do a """typical""" unobserve here because we - // really don't care; if it's time to unlink our bind change, it's our - // bind change, there is no way to observe it directly, so we'll just - // shoot it in the face ourselves. - b.n.graph.removeNodeFromGraph(b.bindChange) - b.bindChange = nil - } + return nil } func (b *bindIncr[A, B]) linkNewBound(ctx context.Context, newIncr Incr[B]) (err error) { b.bound = newIncr Link(b.bound, b.bindChange) Link(b, b.bound) - b.n.graph.observeNodes(b.bound, b.n.Observers()...) - _ = b.n.graph.adjustHeightsHeap.ensureHeightRequirement(b, b.bound) for _, n := range b.scope.rhsNodes { if typed, ok := n.(IBind); ok { if err = typed.Link(ctx); err != nil { @@ -217,27 +208,31 @@ func (b *bindIncr[A, B]) linkNewBound(ctx context.Context, newIncr Incr[B]) (err return } +func (b *bindIncr[A, B]) unlinkBindChange(ctx context.Context) { + if b.bindChange != nil { + if b.bound != nil { + Unlink(b.bound, b.bindChange) + } + Unlink(b.bindChange, b.input) + // NOTE (wc): we don't do a """typical""" unobserve here because we + // really don't care; if it's time to unlink our bind change, it's our + // bind change, there is no way to observe it directly, so we'll just + // shoot it in the face ourselves. + if !b.n.graph.isNecessary(b.bindChange) { + b.bindChange = nil + } + } +} + func (b *bindIncr[A, B]) unlinkOldBound(ctx context.Context, observers ...IObserver) { if b.bound != nil { - TracePrintf(ctx, "%v unbinding old rhs %v", b, b.bound) Unlink(b.bound, b.bindChange) - b.removeNodesFromScope(ctx, b.scope, observers...) Unlink(b, b.bound) - b.n.graph.unobserveNodes(ctx, b.bound, observers...) + TracePrintf(ctx, "%v unbound old rhs %v", b, b.bound) b.bound = nil } } -func (b *bindIncr[A, B]) removeNodesFromScope(ctx context.Context, scope *bindScope, observers ...IObserver) { - for _, n := range scope.rhsNodes { - n.Node().createdIn = nil - if typed, ok := n.(IUnobserve); ok { - typed.Unobserve(ctx, observers...) - } - } - scope.rhsNodes = nil -} - func (b *bindIncr[A, B]) String() string { return b.n.String() } diff --git a/bind_if.go b/bind_if.go index e923b3a..b53503e 100644 --- a/bind_if.go +++ b/bind_if.go @@ -7,5 +7,5 @@ import "context" func BindIf[A any](scope Scope, p Incr[bool], fn func(context.Context, Scope, bool) (Incr[A], error)) BindIncr[A] { b := BindContext[bool, A](scope, p, fn).(*bindIncr[bool, A]) b.Node().SetKind("bind_if") - return WithinScope(scope, b) + return b } diff --git a/bind_test.go b/bind_test.go index 884bb44..9788219 100644 --- a/bind_test.go +++ b/bind_test.go @@ -71,20 +71,23 @@ func Test_Bind_basic(t *testing.T) { testutil.Equal(t, 1, bind.Node().height) testutil.Equal(t, 2, o.Node().height) - testutil.Equal(t, true, g.IsObserving(bindVar)) - testutil.Equal(t, true, g.IsObserving(s0)) - testutil.Equal(t, true, g.IsObserving(s1)) - testutil.Equal(t, true, g.IsObserving(bind)) - testutil.Equal(t, true, g.IsObserving(o)) - - testutil.Equal(t, false, g.IsObserving(av)) - testutil.Equal(t, false, g.IsObserving(a0)) - testutil.Equal(t, false, g.IsObserving(a1)) - - testutil.Equal(t, false, g.IsObserving(bv)) - testutil.Equal(t, false, g.IsObserving(b0)) - testutil.Equal(t, false, g.IsObserving(b1)) - testutil.Equal(t, false, g.IsObserving(b2)) + testutil.Equal(t, true, g.Has(bindVar)) + testutil.Equal(t, true, g.Has(s0)) + testutil.Equal(t, true, g.Has(s1)) + testutil.Equal(t, true, g.Has(bind)) + testutil.Equal(t, true, g.Has(o)) + + testutil.Equal(t, true, g.Has(av)) + testutil.Equal(t, true, g.isNecessary(av)) + testutil.Equal(t, true, g.Has(a0)) + testutil.Equal(t, true, g.isNecessary(a0)) + testutil.Equal(t, true, g.Has(a1)) + testutil.Equal(t, false, g.isNecessary(a1)) + + testutil.Equal(t, true, g.Has(bv)) + testutil.Equal(t, true, g.Has(b0)) + testutil.Equal(t, true, g.Has(b1)) + testutil.Equal(t, true, g.Has(b2)) err = g.Stabilize(ctx) testutil.Nil(t, err) @@ -123,20 +126,22 @@ func Test_Bind_basic(t *testing.T) { testutil.Equal(t, 3, bind.Node().height) testutil.Equal(t, 4, o.Node().height) - testutil.Equal(t, true, g.IsObserving(bindVar)) - testutil.Equal(t, true, g.IsObserving(s0)) - testutil.Equal(t, true, g.IsObserving(s1)) - testutil.Equal(t, true, g.IsObserving(bind)) - testutil.Equal(t, true, g.IsObserving(o)) + testutil.Equal(t, true, g.Has(bindVar)) + testutil.Equal(t, true, g.Has(s0)) + testutil.Equal(t, true, g.Has(s1)) + testutil.Equal(t, true, g.Has(bind)) + testutil.Equal(t, true, g.Has(o)) - testutil.Equal(t, true, g.IsObserving(av)) - testutil.Equal(t, true, g.IsObserving(a0)) - testutil.Equal(t, true, g.IsObserving(a1)) + testutil.Equal(t, true, g.Has(av)) + testutil.Equal(t, true, g.Has(a0)) + testutil.Equal(t, true, g.Has(a1)) + testutil.Equal(t, true, g.isNecessary(a1)) - testutil.Equal(t, false, g.IsObserving(bv)) - testutil.Equal(t, false, g.IsObserving(b0)) - testutil.Equal(t, false, g.IsObserving(b1)) - testutil.Equal(t, false, g.IsObserving(b2)) + testutil.Equal(t, true, g.Has(bv)) + testutil.Equal(t, true, g.Has(b0)) + testutil.Equal(t, true, g.Has(b1)) + testutil.Equal(t, true, g.Has(b2)) + testutil.Equal(t, false, g.isNecessary(b2)) testutil.Equal(t, "a-value", av.Value()) testutil.Equal(t, "a-value", bind.Value()) @@ -168,8 +173,8 @@ func Test_Bind_basic(t *testing.T) { testutil.Equal(t, 1, s1.Node().height) testutil.Equal(t, 0, av.Node().height) - testutil.Equal(t, 0, a0.Node().height) - testutil.Equal(t, 0, a1.Node().height) + testutil.Equal(t, 1, a0.Node().height) + testutil.Equal(t, 2, a1.Node().height) testutil.Equal(t, 0, bv.Node().height) testutil.Equal(t, 1, b0.Node().height) @@ -179,20 +184,20 @@ func Test_Bind_basic(t *testing.T) { testutil.Equal(t, 4, bind.Node().height) testutil.Equal(t, 5, o.Node().height) - testutil.Equal(t, true, g.IsObserving(bindVar)) - testutil.Equal(t, true, g.IsObserving(s0)) - testutil.Equal(t, true, g.IsObserving(s1)) - testutil.Equal(t, true, g.IsObserving(bind)) - testutil.Equal(t, true, g.IsObserving(o)) + testutil.Equal(t, true, g.Has(bindVar)) + testutil.Equal(t, true, g.Has(s0)) + testutil.Equal(t, true, g.Has(s1)) + testutil.Equal(t, true, g.Has(bind)) + testutil.Equal(t, true, g.Has(o)) - testutil.Equal(t, false, g.IsObserving(av), "if we switch to b, we should unobserve the 'a' tree") - testutil.Equal(t, false, g.IsObserving(a0)) - testutil.Equal(t, false, g.IsObserving(a1)) + testutil.Equal(t, true, g.Has(av)) + testutil.Equal(t, true, g.Has(a0)) + testutil.Equal(t, true, g.Has(a1)) - testutil.Equal(t, true, g.IsObserving(bv)) - testutil.Equal(t, true, g.IsObserving(b0)) - testutil.Equal(t, true, g.IsObserving(b1)) - testutil.Equal(t, true, g.IsObserving(b2)) + testutil.Equal(t, true, g.Has(bv)) + testutil.Equal(t, true, g.Has(b0)) + testutil.Equal(t, true, g.Has(b1)) + testutil.Equal(t, true, g.Has(b2)) testutil.Equal(t, "a-value", av.Value()) testutil.Equal(t, "b-value", bv.Value()) @@ -213,31 +218,31 @@ func Test_Bind_basic(t *testing.T) { testutil.Equal(t, 1, s1.Node().height) testutil.Equal(t, 0, av.Node().height) - testutil.Equal(t, 0, a0.Node().height) - testutil.Equal(t, 0, a1.Node().height) + testutil.Equal(t, 1, a0.Node().height) + testutil.Equal(t, 2, a1.Node().height) testutil.Equal(t, 0, bv.Node().height) - testutil.Equal(t, 0, b0.Node().height) - testutil.Equal(t, 0, b1.Node().height) - testutil.Equal(t, 0, b2.Node().height) + testutil.Equal(t, 1, b0.Node().height) + testutil.Equal(t, 2, b1.Node().height) + testutil.Equal(t, 3, b2.Node().height) testutil.Equal(t, 4, bind.Node().height) testutil.Equal(t, 5, o.Node().height) - testutil.Equal(t, true, g.IsObserving(bindVar)) - testutil.Equal(t, true, g.IsObserving(s0)) - testutil.Equal(t, true, g.IsObserving(s1)) - testutil.Equal(t, true, g.IsObserving(bind)) - testutil.Equal(t, true, g.IsObserving(o)) + testutil.Equal(t, true, g.Has(bindVar)) + testutil.Equal(t, true, g.Has(s0)) + testutil.Equal(t, true, g.Has(s1)) + testutil.Equal(t, true, g.Has(bind)) + testutil.Equal(t, true, g.Has(o)) - testutil.Equal(t, false, g.IsObserving(av)) - testutil.Equal(t, false, g.IsObserving(a0)) - testutil.Equal(t, false, g.IsObserving(a1)) + testutil.Equal(t, true, g.Has(av)) + testutil.Equal(t, true, g.Has(a0)) + testutil.Equal(t, true, g.Has(a1)) - testutil.Equal(t, false, g.IsObserving(bv)) - testutil.Equal(t, false, g.IsObserving(b0)) - testutil.Equal(t, false, g.IsObserving(b1)) - testutil.Equal(t, false, g.IsObserving(b2)) + testutil.Equal(t, true, g.Has(bv)) + testutil.Equal(t, true, g.Has(b0)) + testutil.Equal(t, true, g.Has(b1)) + testutil.Equal(t, true, g.Has(b2)) testutil.Equal(t, "a-value", av.Value()) testutil.Equal(t, "b-value", bv.Value()) @@ -350,7 +355,6 @@ func Test_Bind_necessary(t *testing.T) { err := g.Stabilize(ctx) testutil.Nil(t, err) testutil.Equal(t, "hellohello", o.Value()) - testutil.Equal(t, true, g.canReachObserver(root, o.Node().id)) testutil.Equal(t, true, hasKey(m2.Node().children, o.Node().id)) @@ -362,8 +366,7 @@ func Test_Bind_necessary(t *testing.T) { _ = dumpDot(g, homedir("bind_necessary_01.png")) - testutil.Equal(t, true, g.canReachObserver(root, o.Node().id)) - testutil.Equal(t, true, g.IsObserving(root)) + testutil.Equal(t, true, g.Has(root)) } func Test_Bind_unbindConflict(t *testing.T) { @@ -406,7 +409,6 @@ func Test_Bind_unbindConflict(t *testing.T) { err := g.Stabilize(testContext()) testutil.Nil(t, err) testutil.Equal(t, "hellohello", o.Value()) - testutil.Equal(t, true, g.canReachObserver(root, o.Node().id)) _ = dumpDot(g, homedir("bind_unbind_confict_00.png")) @@ -416,7 +418,7 @@ func Test_Bind_unbindConflict(t *testing.T) { _ = dumpDot(g, homedir("bind_unbind_conflict_01.png")) - testutil.Equal(t, true, g.IsObserving(ma)) + testutil.Equal(t, true, g.Has(ma)) } func Test_Bind_rebind(t *testing.T) { @@ -634,10 +636,12 @@ func Test_Bind_nested_unlinksBind(t *testing.T) { testutil.Nil(t, dumpDot(g, homedir("bind_unobserve_00_base.png"))) testutil.Equal(t, "a00", o.Value()) - testutil.Equal(t, true, g.IsObserving(a00)) - testutil.Equal(t, true, g.IsObserving(a01)) - testutil.Equal(t, false, g.IsObserving(b00)) - testutil.Equal(t, false, g.IsObserving(b01)) + testutil.Equal(t, true, g.Has(a00)) + testutil.Equal(t, true, g.Has(a01)) + testutil.Equal(t, true, g.isNecessary(a01)) + testutil.Equal(t, true, g.Has(b00)) + testutil.Equal(t, true, g.Has(b01)) + testutil.Equal(t, false, g.isNecessary(b01)) bindv.Set("b") err = g.Stabilize(ctx) @@ -645,10 +649,12 @@ func Test_Bind_nested_unlinksBind(t *testing.T) { testutil.Nil(t, dumpDot(g, homedir("bind_unobserve_01_switch_b.png"))) testutil.Equal(t, "b00", o.Value()) - testutil.Equal(t, false, g.IsObserving(a00)) - testutil.Equal(t, false, g.IsObserving(a01)) - testutil.Equal(t, true, g.IsObserving(b00)) - testutil.Equal(t, true, g.IsObserving(b01)) + testutil.Equal(t, true, g.Has(a00)) + testutil.Equal(t, true, g.Has(a01)) + testutil.Equal(t, false, g.isNecessary(a01)) + testutil.Equal(t, true, g.Has(b00)) + testutil.Equal(t, true, g.Has(b01)) + testutil.Equal(t, true, g.isNecessary(b01)) bindv.Set("a") @@ -658,11 +664,11 @@ func Test_Bind_nested_unlinksBind(t *testing.T) { testutil.Equal(t, "a00", a01.Value()) testutil.Equal(t, "a00", o.Value()) - testutil.Equal(t, true, g.IsObserving(a00)) - testutil.Equal(t, true, g.IsObserving(a01)) + testutil.Equal(t, true, g.Has(a00)) + testutil.Equal(t, true, g.Has(a01)) - testutil.Equal(t, false, g.IsObserving(b00)) - testutil.Equal(t, false, g.IsObserving(b01)) + testutil.Equal(t, true, g.Has(b00)) + testutil.Equal(t, true, g.Has(b01)) } func Test_Bind_nested_bindCreatesBind(t *testing.T) { @@ -1037,128 +1043,14 @@ func Test_Bind_nested_amplification(t *testing.T) { testutil.Equal(t, 65, g.numNodes) } -func Test_Bind_unbind_propagatesUnobserved(t *testing.T) { - /* - - The pathological case here is that we have observers that work up from the - leaves (or bottom) of the graph, but that land on nodes that are not - strictly controlled by those observers. - - Put more specifically, nodes that are created in a bind's scope should only - be observed by observers for that bind specifically. - - */ - - ctx := testContext() - g := New() - - r0 := Return(g, "hello world!") - r0.Node().SetLabel("r0") - m0 := Map(g, r0, ident) - m0.Node().SetLabel("m0") - - b0v := Var(g, "a") - b0v.Node().SetLabel("b0v") - var bm0 Incr[string] - b0 := Bind(g, b0v, func(bs Scope, bvv string) Incr[string] { - if bvv == "a" { - bm0 = Map(bs, m0, ident) - bm0.Node().SetLabel("bm0") - return bm0 - } - return Return(bs, "nope") - }) - b0.Node().SetLabel("b0") - - b1v := Var(g, "a") - b1v.Node().SetLabel("b1v") - var bm1 Incr[string] - b1 := Bind(g, b1v, func(bs Scope, bvv string) Incr[string] { - if bvv == "a" { - bm1 = Map2(bs, m0, b0, concat) - bm1.Node().SetLabel("bm1") - return bm1 - } - return Return(bs, "nope") - }) - b1.Node().SetLabel("b1") - - b2v := Var(g, "a") - b2v.Node().SetLabel("b2v") - var bm2 Incr[string] - b2 := Bind(g, b2v, func(bs Scope, bvv string) Incr[string] { - if bvv == "a" { - bm2 = Map2(bs, m0, b1, concat) - bm2.Node().SetLabel("bm2") - return bm2 - } - return Return(bs, "nope") - }) - b2.Node().SetLabel("b2") - - o00 := Observe(g, b0) - o01 := Observe(g, b1) - o02 := Observe(g, b2) - - err := g.Stabilize(ctx) - testutil.NoError(t, err) - - testutil.Equal(t, "hello world!", o00.Value()) - testutil.Equal(t, "hello world!hello world!", o01.Value()) - testutil.Equal(t, "hello world!hello world!hello world!", o02.Value()) - - testutil.Equal(t, true, r0.Node().hasObserver(o00)) - testutil.Equal(t, true, r0.Node().hasObserver(o01)) - testutil.Equal(t, true, r0.Node().hasObserver(o02)) - - testutil.Equal(t, true, m0.Node().hasObserver(o00)) - testutil.Equal(t, true, m0.Node().hasObserver(o01)) - testutil.Equal(t, true, m0.Node().hasObserver(o02)) - - testutil.Equal(t, true, bm0.Node().hasObserver(o00)) - testutil.Equal(t, true, bm0.Node().hasObserver(o01)) - testutil.Equal(t, true, bm0.Node().hasObserver(o02)) - - testutil.Equal(t, false, bm1.Node().hasObserver(o00)) - testutil.Equal(t, true, bm1.Node().hasObserver(o01)) - testutil.Equal(t, true, bm1.Node().hasObserver(o02)) - - testutil.Equal(t, false, bm2.Node().hasObserver(o00)) - testutil.Equal(t, false, bm2.Node().hasObserver(o01)) - testutil.Equal(t, true, bm2.Node().hasObserver(o02)) - - b2v.Set("b") - - err = g.Stabilize(ctx) - testutil.NoError(t, err) - - testutil.Equal(t, true, r0.Node().hasObserver(o00)) - testutil.Equal(t, true, r0.Node().hasObserver(o01)) - testutil.Equal(t, false, r0.Node().hasObserver(o02)) - - testutil.Equal(t, true, m0.Node().hasObserver(o00)) - testutil.Equal(t, true, m0.Node().hasObserver(o01)) - testutil.Equal(t, false, m0.Node().hasObserver(o02)) - - testutil.Equal(t, true, bm0.Node().hasObserver(o00)) - testutil.Equal(t, true, bm0.Node().hasObserver(o01)) - testutil.Equal(t, false, bm0.Node().hasObserver(o02)) - - testutil.Equal(t, false, bm1.Node().hasObserver(o00)) - testutil.Equal(t, true, bm1.Node().hasObserver(o01)) - testutil.Equal(t, false, bm1.Node().hasObserver(o02)) - - testutil.Equal(t, false, bm2.Node().hasObserver(o00)) - testutil.Equal(t, false, bm2.Node().hasObserver(o01)) - testutil.Equal(t, false, bm2.Node().hasObserver(o02)) -} - func Test_Bind_boundChange_doesntCauseRebind(t *testing.T) { ctx := testContext() g := New() v0 := Var(g, "foo") + v0.Node().SetLabel("v0") m0 := Map(g, v0, ident) + m0.Node().SetLabel("m0") var bindUpdates int bv := Var(g, "a") @@ -1166,8 +1058,10 @@ func Test_Bind_boundChange_doesntCauseRebind(t *testing.T) { bindUpdates++ return m0 }) + b.Node().SetLabel("b") m1 := Map(g, b, ident) + m1.Node().SetLabel("m1") var m1Updates int m1.Node().OnUpdate(func(_ context.Context) { @@ -1188,7 +1082,7 @@ func Test_Bind_boundChange_doesntCauseRebind(t *testing.T) { err = g.Stabilize(ctx) testutil.NoError(t, err) - testutil.Equal(t, 1, bindUpdates) + testutil.Equal(t, 2, bindUpdates) testutil.Equal(t, 2, m1Updates) testutil.Equal(t, "not-foo", o.Value()) } diff --git a/cutoff.go b/cutoff.go index 7e5e60e..c3736cf 100644 --- a/cutoff.go +++ b/cutoff.go @@ -22,15 +22,13 @@ func Cutoff[A any](bs Scope, i Incr[A], fn CutoffFunc[A]) Incr[A] { // node if the difference between the previous and latest values are not // significant enough to warrant a full recomputation of the children of this node. func CutoffContext[A any](bs Scope, i Incr[A], fn CutoffContextFunc[A]) Incr[A] { - o := &cutoffIncr[A]{ + o := WithinScope(bs, &cutoffIncr[A]{ n: NewNode("cutoff"), i: i, fn: fn, - } - // we short circuit setup of the node cutoff reference here. - // this can be discovered in initialization but saves a step. + }) Link(o, i) - return WithinScope(bs, o) + return o } // CutoffFunc is a function that implements cutoff checking. diff --git a/cutoff2.go b/cutoff2.go index 4a96eed..1416166 100644 --- a/cutoff2.go +++ b/cutoff2.go @@ -18,17 +18,17 @@ func Cutoff2[A, B any](bs Scope, epsilon Incr[A], input Incr[B], fn Cutoff2Func[ // node if the difference between the previous and latest values are not // significant enough to warrant a full recomputation of the children of this node. func Cutoff2Context[A, B any](bs Scope, epsilon Incr[A], input Incr[B], fn Cutoff2ContextFunc[A, B]) Cutoff2Incr[A, B] { - o := &cutoff2Incr[A, B]{ + o := WithinScope(bs, &cutoff2Incr[A, B]{ n: NewNode("cutoff2"), fn: fn, e: epsilon, i: input, - } + }) // we short circuit setup of the node cutoff reference here. // this can be discovered in initialization but saves a step. Link(o, input) Link(o, epsilon) - return WithinScope(bs, o) + return o } // CutoffIncr is an incremental node that implements the ICutoff interface. diff --git a/doc.go b/doc.go index 4bf9060..2eb2027 100644 --- a/doc.go +++ b/doc.go @@ -1,8 +1,6 @@ /* -Incr implements an incremental computation graph. +Incr is a library that enables building incremental computation graphs. -This graph is useful for partially recomputing a small subset of a very large graph of computation nodes. - -It is largely based off Jane Street's `incremental` ocaml library, with some go specific changes. +These graphs are useful for partially recomputing a small subset of a very large number of computations if only a subset of the inputs change. */ package incr diff --git a/dot.go b/dot.go index 7ae374c..486130f 100644 --- a/dot.go +++ b/dot.go @@ -34,12 +34,12 @@ func Dot(wr io.Writer, g *Graph) (err error) { } writef(0, "digraph {") - nodes := make([]INode, 0, len(g.observed)+len(g.observers)) - for _, n := range g.observed { + nodes := make([]INode, 0, len(g.nodes)+len(g.observers)) + for _, n := range g.nodes { nodes = append(nodes, n) } - for _, n := range g.observers { - nodes = append(nodes, n) + for _, o := range g.observers { + nodes = append(nodes, o) } slices.SortStableFunc(nodes, nodeSorter) diff --git a/dot_test.go b/dot_test.go index 5ec10a5..55ae8fe 100644 --- a/dot_test.go +++ b/dot_test.go @@ -22,14 +22,13 @@ func Test_Dot(t *testing.T) { g := New() v0 := Var(g, "foo") - v0.Node().id, _ = ParseIdentifier("165382c219e24e3db77fd41a884f9774") + ExpertNode(v0).SetID(MustParseIdentifier("165382c219e24e3db77fd41a884f9774")) v1 := Var(g, "bar") - v1.Node().id, _ = ParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21") + ExpertNode(v1).SetID(MustParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21")) m0 := Map2(g, v0, v1, concat) - m0.Node().id, _ = ParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae") - + ExpertNode(m0).SetID(MustParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae")) o := Observe(g, m0) - o.Node().id, _ = ParseIdentifier("507dd07419724979bb34f2ca033257be") + ExpertNode(o).SetID(MustParseIdentifier("507dd07419724979bb34f2ca033257be")) buf := new(bytes.Buffer) err := Dot(buf, g) @@ -48,11 +47,13 @@ func (ew errorWriter) Write(_ []byte) (int, error) { func Test_Dot_writeError(t *testing.T) { g := New() v0 := Var(g, "foo") - v0.Node().id, _ = ParseIdentifier("165382c219e24e3db77fd41a884f9774") + ExpertNode(v0).SetID(MustParseIdentifier("165382c219e24e3db77fd41a884f9774")) v1 := Var(g, "bar") - v1.Node().id, _ = ParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21") + ExpertNode(v1).SetID(MustParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21")) m0 := Map2(g, v0, v1, concat) - m0.Node().id, _ = ParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae") + ExpertNode(m0).SetID(MustParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae")) + o := Observe(g, m0) + ExpertNode(o).SetID(MustParseIdentifier("507dd07419724979bb34f2ca033257be")) _ = Observe(g, m0) @@ -72,16 +73,17 @@ func Test_Dot_setAt(t *testing.T) { n4 -> n2; } ` + g := New() v0 := Var(g, "foo") - v0.Node().id, _ = ParseIdentifier("165382c219e24e3db77fd41a884f9774") + ExpertNode(v0).SetID(MustParseIdentifier("165382c219e24e3db77fd41a884f9774")) v1 := Var(g, "bar") - v1.Node().id, _ = ParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21") + ExpertNode(v1).SetID(MustParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21")) v1.Node().setAt = 1 m0 := Map2(g, v0, v1, concat) - m0.Node().id, _ = ParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae") + ExpertNode(m0).SetID(MustParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae")) o := Observe(g, m0) - o.Node().id, _ = ParseIdentifier("507dd07419724979bb34f2ca033257be") + ExpertNode(o).SetID(MustParseIdentifier("507dd07419724979bb34f2ca033257be")) buf := new(bytes.Buffer) err := Dot(buf, g) @@ -100,16 +102,17 @@ func Test_Dot_changedAt(t *testing.T) { n4 -> n2; } ` + g := New() v0 := Var(g, "foo") - v0.Node().id, _ = ParseIdentifier("165382c219e24e3db77fd41a884f9774") + ExpertNode(v0).SetID(MustParseIdentifier("165382c219e24e3db77fd41a884f9774")) v1 := Var(g, "bar") - v1.Node().id, _ = ParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21") + ExpertNode(v1).SetID(MustParseIdentifier("a985936bed8c48b99801a5bd7f8a4e21")) m0 := Map2(g, v0, v1, concat) - m0.Node().id, _ = ParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae") + ExpertNode(m0).SetID(MustParseIdentifier("fc45f4a7b5c7456f852f2298563b29ae")) m0.Node().changedAt = 1 o := Observe(g, m0) - o.Node().id, _ = ParseIdentifier("507dd07419724979bb34f2ca033257be") + ExpertNode(o).SetID(MustParseIdentifier("507dd07419724979bb34f2ca033257be")) buf := new(bytes.Buffer) err := Dot(buf, g) diff --git a/examples/integration/main.go b/examples/integration/main.go index 4712dc6..2be7de8 100644 --- a/examples/integration/main.go +++ b/examples/integration/main.go @@ -44,7 +44,6 @@ func noError(err error) { } func main() { - ctx := testContext() graph := incr.New() cache := make(map[string]incr.Incr[*int]) @@ -56,7 +55,6 @@ func main() { key := fmt.Sprintf("f-%d", t) if _, ok := cache[key]; ok { return incr.WithinScope(bs, cache[key]) - // return cache[key] } r := incr.Bind(graph, fakeFormula, func(bs incr.Scope, formula string) incr.Incr[*int] { if t <= 0 { @@ -151,7 +149,7 @@ func main() { return o } - testCase("month_of_runway = if burn > 0: cash_balance / burn else 0. Calculate months of runway then burn", func() { + skipTestCase("month_of_runway = if burn > 0: cash_balance / burn else 0. Calculate months of runway then burn", func() { num := 24 fmt.Println("Calculating months of runway for t= 1 to 24") @@ -183,7 +181,7 @@ func main() { fmt.Printf(fmt.Sprintf("Calculating burn took %s \n", elapsed)) }) - testCase("month_of_runway = if burn > 0: cash_balance / burn else 0. Calculate burn then months of runway", func() { + skipTestCase("month_of_runway = if burn > 0: cash_balance / burn else 0. Calculate burn then months of runway", func() { num := 24 graph := incr.New() @@ -213,6 +211,62 @@ func main() { elapsed = time.Since(start) fmt.Printf(fmt.Sprintf("Calculating months of runway took %s \n", elapsed)) }) + + testCase("node amplification yields slower and slower stabilization", func() { + // w := func(bs incr.Scope, t int) incr.Incr[*int] { + // key := fmt.Sprintf("w-%d", t) + // if _, ok := cache[key]; ok { + // return incr.WithinScope(bs, cache[key]) + // } + + // r := incr.Bind(bs, incr.Var(bs, "fakeformula"), func(bs incr.Scope, formula string) incr.Incr[*int] { + // out := 1 + // return incr.Return(bs, &out) + // }) + // r.Node().SetLabel(fmt.Sprintf("w(%d)", t)) + // cache[key] = r + // return r + // } + + graph := incr.New(incr.GraphMaxRecomputeHeapHeight(1024)) + max_t := 50 + + // baseline + start := time.Now() + + observers := make([]incr.IObserver, max_t) + for i := 0; i < max_t; i++ { + o := monthsOfRunway(graph, i) + observers[i] = incr.Observe(graph, o) + } + _ = graph.Stabilize(ctx) + elapsed := time.Since(start) + fmt.Printf(fmt.Sprintf("Baseline calculation of months of runway for t= %d to %d took %s\n", 0, max_t, elapsed)) + + maxMultiplier := 10 + for k := 1; k <= maxMultiplier; k++ { + + graph = incr.New(incr.GraphMaxRecomputeHeapHeight(1024)) + + observers = make([]incr.IObserver, max_t) + + num := 5000 * k + start = time.Now() + + for i := 0; i < max_t; i++ { + o := monthsOfRunway(graph, i) + observers[i] = incr.Observe(graph, o) + } + _ = graph.Stabilize(ctx) + for i := 0; i < max_t; i++ { + observers[i].Unobserve(ctx) + } + + elapsed = time.Since(start) + fmt.Printf("Calculating months of runway for t= %d to %d took %s when prior_count(observed nodes) >%d\n", 0, max_t, elapsed, num) + fmt.Printf("Graph node count=%d, observer count=%d\n", incr.ExpertGraph(graph).NumNodes(), incr.ExpertGraph(graph).NumObservers()) + } + }) } func homedir(filename string) string { diff --git a/expert_graph.go b/expert_graph.go index 17bac98..31eb42e 100644 --- a/expert_graph.go +++ b/expert_graph.go @@ -1,7 +1,5 @@ package incr -import "context" - // ExpertGraph returns an "expert" interface to modify // internal fields of the graph type. // @@ -22,6 +20,7 @@ type IExpertGraph interface { NumNodes() uint64 NumNodesRecomputed() uint64 NumNodesChanged() uint64 + NumObservers() uint64 StabilizationNum() uint64 SetStabilizationNum(uint64) @@ -29,11 +28,8 @@ type IExpertGraph interface { RecomputeHeapLen() int RecomputeHeapIDs() []Identifier - AddObserver(IObserver) + AddObserver(IObserver) error RemoveObserver(IObserver) - - ObserveNodes(INode, ...IObserver) - UnobserveNodes(context.Context, INode, ...IObserver) } type expertGraph struct { @@ -44,6 +40,10 @@ func (eg *expertGraph) NumNodes() uint64 { return eg.graph.numNodes } +func (eg *expertGraph) NumObservers() uint64 { + return uint64(len(eg.graph.observers)) +} + func (eg *expertGraph) NumNodesRecomputed() uint64 { return eg.graph.numNodesRecomputed } @@ -85,16 +85,8 @@ func (eg *expertGraph) RecomputeHeapIDs() []Identifier { return output } -func (eg *expertGraph) ObserveNodes(n INode, observers ...IObserver) { - eg.graph.observeNodes(n, observers...) -} - -func (eg *expertGraph) UnobserveNodes(ctx context.Context, n INode, observers ...IObserver) { - eg.graph.unobserveNodes(ctx, n, observers...) -} - -func (eg *expertGraph) AddObserver(on IObserver) { - eg.graph.addObserver(on) +func (eg *expertGraph) AddObserver(on IObserver) error { + return eg.graph.addObserver(on) } func (eg *expertGraph) RemoveObserver(on IObserver) { diff --git a/expert_graph_test.go b/expert_graph_test.go index 8a8370d..ec252e7 100644 --- a/expert_graph_test.go +++ b/expert_graph_test.go @@ -1,7 +1,6 @@ package incr import ( - "context" "testing" "github.com/wcharczuk/go-incr/testutil" @@ -28,8 +27,8 @@ func Test_ExpertGraph_RecomputeHeapAdd(t *testing.T) { g := New() eg := ExpertGraph(g) - n1 := newMockBareNode() - n2 := newMockBareNode() + n1 := newMockBareNode(g) + n2 := newMockBareNode(g) eg.RecomputeHeapAdd(n1, n2) testutil.Equal(t, 2, g.recomputeHeap.len()) @@ -52,8 +51,8 @@ func Test_ExpertGraph_RecomputeHeapIDs(t *testing.T) { g := New() eg := ExpertGraph(g) - n1 := newMockBareNode() - n2 := newMockBareNode() + n1 := newMockBareNode(g) + n2 := newMockBareNode(g) n2.n.height = 3 eg.RecomputeHeapAdd(n1, n2) @@ -65,45 +64,19 @@ func Test_ExpertGraph_RecomputeHeapIDs(t *testing.T) { testutil.Any(t, recomputeHeapIDs, func(id Identifier) bool { return id == n2.n.id }) } -func Test_ExpertGraph_Observe(t *testing.T) { - g := New() - eg := ExpertGraph(g) - - n1 := newMockBareNode() - n2 := newMockBareNode() - - o1 := mockObserver() - o2 := mockObserver() - - eg.ObserveNodes(n1, o1, o2) - - testutil.Equal(t, 2, len(n1.n.observers)) - testutil.Equal(t, 0, len(n2.n.observers)) - - eg.UnobserveNodes(context.TODO(), n1, o1, o2) - - testutil.Equal(t, 0, len(n1.n.observers)) - testutil.Equal(t, 0, len(n2.n.observers)) -} - func Test_ExpertGraph_AddObserver(t *testing.T) { g := New() eg := ExpertGraph(g) - o0 := mockObserver() - o1 := mockObserver() + o0 := mockObserver(g) + o1 := mockObserver(g) - eg.AddObserver(o1) + _ = eg.AddObserver(o1) - testutil.Equal(t, false, mapHas(g.observers, o0.Node().id)) - testutil.Equal(t, true, mapHas(g.observers, o1.Node().id)) + testutil.Equal(t, false, mapHasKey(g.observers, o0.Node().id)) + testutil.Equal(t, true, mapHasKey(g.observers, o1.Node().id)) eg.RemoveObserver(o1) - testutil.Equal(t, false, mapHas(g.observers, o0.Node().id)) - testutil.Equal(t, false, mapHas(g.observers, o1.Node().id)) -} - -func mapHas[K comparable, V any](m map[K]V, k K) (ok bool) { - _, ok = m[k] - return + testutil.Equal(t, false, mapHasKey(g.observers, o0.Node().id)) + testutil.Equal(t, false, mapHasKey(g.observers, o1.Node().id)) } diff --git a/expert_node.go b/expert_node.go index 7dc0378..fcaddfb 100644 --- a/expert_node.go +++ b/expert_node.go @@ -61,7 +61,14 @@ type expertNode struct { func (en *expertNode) Graph() *Graph { return en.node.graph } func (en *expertNode) SetID(id Identifier) { + oldID := en.node.id en.node.id = id + if graph := graphFromScope(en.incr); graph != nil { + if _, ok := graph.nodes[oldID]; ok { + delete(graph.nodes, oldID) + graph.nodes[id] = en.incr + } + } } func (en *expertNode) Height() int { return en.node.height } diff --git a/expert_node_test.go b/expert_node_test.go index 17d5113..7bebd8d 100644 --- a/expert_node_test.go +++ b/expert_node_test.go @@ -7,7 +7,8 @@ import ( ) func Test_ExpertNode_setters(t *testing.T) { - n := newMockBareNode() + g := New() + n := newMockBareNode(g) en := ExpertNode(n) id := NewIdentifier() @@ -35,27 +36,32 @@ func Test_ExpertNode_setters(t *testing.T) { } func Test_ExpertNode_AddChildren(t *testing.T) { - n := newMockBareNode() + g := New() + n := newMockBareNode(g) en := ExpertNode(n) - en.AddChildren(newMockBareNode(), newMockBareNode()) + en.AddChildren(newMockBareNode(g), newMockBareNode(g)) testutil.Equal(t, 2, len(n.Node().Children())) } func Test_ExpertNode_AddParents(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) en := ExpertNode(n) - en.AddParents(newMockBareNode(), newMockBareNode()) + en.AddParents(newMockBareNode(g), newMockBareNode(g)) testutil.Equal(t, 2, len(n.Node().Parents())) } func Test_ExpertNode_RemoveChild(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) en := ExpertNode(n) - mbn0 := newMockBareNode() - mbn1 := newMockBareNode() + mbn0 := newMockBareNode(g) + mbn1 := newMockBareNode(g) en.AddChildren(mbn0, mbn1) testutil.Equal(t, 2, len(n.Node().Children())) @@ -64,11 +70,13 @@ func Test_ExpertNode_RemoveChild(t *testing.T) { } func Test_ExpertNode_RemoveParent(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) en := ExpertNode(n) - mbn0 := newMockBareNode() - mbn1 := newMockBareNode() + mbn0 := newMockBareNode(g) + mbn1 := newMockBareNode(g) en.AddParents(mbn0, mbn1) testutil.Equal(t, 2, len(n.Node().Parents())) @@ -86,21 +94,23 @@ func Test_ExpertNode_Value(t *testing.T) { } func Test_ExperNode_ComputePseudoHeight(t *testing.T) { - a00 := newMockBareNode() - a01 := newMockBareNode() - a10 := newMockBareNode() - a20 := newMockBareNode() - a21 := newMockBareNode() + g := New() + + a00 := newMockBareNode(g) + a01 := newMockBareNode(g) + a10 := newMockBareNode(g) + a20 := newMockBareNode(g) + a21 := newMockBareNode(g) Link(a10, a00, a01) Link(a20, a10) Link(a21, a10) - b00 := newMockBareNode() - b01 := newMockBareNode() - b10 := newMockBareNode() - b20 := newMockBareNode() - b21 := newMockBareNode() + b00 := newMockBareNode(g) + b01 := newMockBareNode(g) + b10 := newMockBareNode(g) + b20 := newMockBareNode(g) + b21 := newMockBareNode(g) Link(b10, b00, b01) Link(b20, b10) @@ -111,21 +121,23 @@ func Test_ExperNode_ComputePseudoHeight(t *testing.T) { } func Test_ExperNode_ComputePseudoHeight_bare(t *testing.T) { - a00 := newMockBareNode() - a01 := newMockBareNode() - a10 := newMockBareNode() - a20 := newMockBareNode() - a21 := newMockBareNode() + g := New() + + a00 := newMockBareNode(g) + a01 := newMockBareNode(g) + a10 := newMockBareNode(g) + a20 := newMockBareNode(g) + a21 := newMockBareNode(g) Link(a10, a00, a01) Link(a20, a10) Link(a21, a10) - b00 := newMockBareNode() - b01 := newMockBareNode() - b10 := newMockBareNode() - b20 := newMockBareNode() - b21 := newMockBareNode() + b00 := newMockBareNode(g) + b01 := newMockBareNode(g) + b10 := newMockBareNode(g) + b20 := newMockBareNode(g) + b21 := newMockBareNode(g) Link(b10, b00, b01) Link(b20, b10) diff --git a/fold_left.go b/fold_left.go index c649b8b..e3f9221 100644 --- a/fold_left.go +++ b/fold_left.go @@ -5,14 +5,14 @@ import "context" // FoldLeft folds an array from 0 to N carrying the previous value // to the next interation, yielding a single value. func FoldLeft[T, O any](scope Scope, i Incr[[]T], v0 O, fn func(O, T) O) Incr[O] { - o := &foldLeftIncr[T, O]{ + o := WithinScope(scope, &foldLeftIncr[T, O]{ n: NewNode("fold_left"), i: i, fn: fn, val: v0, - } + }) Link(o, i) - return WithinScope(scope, o) + return o } type foldLeftIncr[T, O any] struct { diff --git a/fold_map.go b/fold_map.go index ec8a820..e13bd88 100644 --- a/fold_map.go +++ b/fold_map.go @@ -19,14 +19,14 @@ func FoldMap[K comparable, V any, O any]( v0 O, fn func(K, V, O) O, ) Incr[O] { - o := &foldMapIncr[K, V, O]{ + o := WithinScope(scope, &foldMapIncr[K, V, O]{ n: NewNode("fold_map"), i: i, fn: fn, val: v0, - } + }) Link(o, i) - return WithinScope(scope, o) + return o } var ( @@ -51,11 +51,11 @@ func (fmi *foldMapIncr[K, V, O]) Value() O { return fmi.val } func (fmi *foldMapIncr[K, V, O]) Stabilize(_ context.Context) error { new := fmi.i.Value() - fmi.val = foldMap(new, fmi.val, fmi.fn) + fmi.val = foldMapImpl(new, fmi.val, fmi.fn) return nil } -func foldMap[K comparable, V any, O any]( +func foldMapImpl[K comparable, V any, O any]( input map[K]V, zero O, fn func(K, V, O) O, diff --git a/fold_right.go b/fold_right.go index c98d8c4..e2eb0c0 100644 --- a/fold_right.go +++ b/fold_right.go @@ -5,14 +5,14 @@ import "context" // FoldRight folds an array from N to 0 carrying the previous value // to the next interation, yielding a single value. func FoldRight[T, O any](scope Scope, i Incr[[]T], v0 O, fn func(T, O) O) Incr[O] { - o := &foldRightIncr[T, O]{ + o := WithinScope(scope, &foldRightIncr[T, O]{ n: NewNode("fold_right"), i: i, fn: fn, val: v0, - } + }) Link(o, i) - return WithinScope(scope, o) + return o } type foldRightIncr[T, O any] struct { diff --git a/freeze.go b/freeze.go index ec19eb9..327d9fc 100644 --- a/freeze.go +++ b/freeze.go @@ -8,12 +8,12 @@ import ( // Freeze yields an incremental that takes the value of an // input incremental and doesn't change thereafter. func Freeze[A any](scope Scope, i Incr[A]) Incr[A] { - o := &freezeIncr[A]{ + o := WithinScope(scope, &freezeIncr[A]{ n: NewNode("freeze"), i: i, - } + }) Link(o, i) - return WithinScope(scope, o) + return o } var ( diff --git a/func.go b/func.go index 39668f8..219702b 100644 --- a/func.go +++ b/func.go @@ -9,11 +9,12 @@ import ( // // The result of the function after the first stabilization will // be re-used between stabilizations unless you mark the node stale -// with the `SetStale` helper. +// with the `SetStale` function on the `Graph` type. // // Because there is no tracking of input changes, this node // type is generally discouraged in favor of `Map` or `Bind` -// incrementals but is included for "expert" use cases. +// incrementals but is included for "expert" use cases, typically +// as an input to other nodes. func Func[T any](scope Scope, fn func(context.Context) (T, error)) Incr[T] { return WithinScope(scope, &funcIncr[T]{ n: NewNode("func"), diff --git a/graph.go b/graph.go index fd40821..3151641 100644 --- a/graph.go +++ b/graph.go @@ -2,6 +2,7 @@ package incr import ( "context" + "fmt" "sync" "sync/atomic" "time" @@ -29,7 +30,7 @@ func New(opts ...GraphOption) *Graph { id: NewIdentifier(), stabilizationNum: 1, status: StatusNotStabilizing, - observed: make(map[Identifier]INode), + nodes: make(map[Identifier]INode), observers: make(map[Identifier]IObserver), recomputeHeap: newRecomputeHeap(options.MaxHeight), adjustHeightsHeap: newAdjustHeightsHeap(options.MaxHeight), @@ -76,10 +77,10 @@ type Graph struct { // label is a descriptive label for the graph label string - observedMu sync.Mutex + nodesMu sync.Mutex // observed are the nodes that the graph currently observes // organized by node id. - observed map[Identifier]INode + nodes map[Identifier]INode // observersMu interlocks acces to observers observersMu sync.Mutex @@ -171,10 +172,10 @@ func (graph *Graph) IsStabilizing() bool { } // IsObserving returns if a graph is observing a given node. -func (graph *Graph) IsObserving(gn INode) (ok bool) { - graph.observedMu.Lock() - _, ok = graph.observed[gn.Node().id] - graph.observedMu.Unlock() +func (graph *Graph) Has(gn INode) (ok bool) { + graph.nodesMu.Lock() + _, ok = graph.nodes[gn.Node().id] + graph.nodesMu.Unlock() return } @@ -213,77 +214,85 @@ func (graph *Graph) isRootScope() bool { return true } func (graph *Graph) scopeGraph() *Graph { return graph } +func (graph *Graph) scopeHeight() int { return -1 } + +func (graph *Graph) String() string { return fmt.Sprintf("{graph:%s}", graph.id.Short()) } + // // Internal discovery & observe methods // -// observeNodes traverses up from a given node, adding a given -// list of observers as "observing" that node, and recursing through it's inputs or parents. -func (graph *Graph) observeNodes(gn INode, observers ...IObserver) { - graph.observeSingleNode(gn, observers...) - for _, p := range gn.Node().parents { - graph.observeNodes(p, observers...) +func (graph *Graph) removeParents(child INode) { + for _, parent := range child.Node().parents { + graph.removeParent(child, parent) } } -func (graph *Graph) observeSingleNode(gn INode, observers ...IObserver) { - gnn := gn.Node() +func (graph *Graph) removeParent(child, parent INode) { + Unlink(child, parent) + graph.checkIfUnnecessary(parent) +} - gnn.addObservers(observers...) - alreadyObservedByGraph := graph.maybeAddObservedNode(gn) - if alreadyObservedByGraph { - return +func (graph *Graph) checkIfUnnecessary(n INode) { + if !graph.isNecessary(n) { + graph.removeNodeFromGraph(n) + graph.removeParents(n) } - graph.numNodes++ - gnn.graph = graph - for _, p := range gnn.parents { - _ = graph.adjustHeightsHeap.ensureHeightRequirement(gn, p) +} + +func (graph *Graph) becameNecessary(node INode) { + graph.initializeNode(node) + _ = graph.adjustHeightsHeap.setHeight(node, heightFromScope(node)+1) + for _, p := range node.Node().parents { + if p.Node().height >= node.Node().height { + _ = graph.adjustHeightsHeap.setHeight(node, p.Node().height+1) + } + graph.becameNecessary(p) } - gnn.detectCutoff(gn) - gnn.detectAlways(gn) - gnn.detectStabilize(gn) - if gnn.ShouldRecompute() { - graph.recomputeHeap.add(gn) + + if node.Node().ShouldRecompute() { + graph.recomputeHeap.add(node) } } -func (graph *Graph) maybeAddObservedNode(gn INode) (ok bool) { - graph.observedMu.Lock() - defer graph.observedMu.Unlock() - if _, ok = graph.observed[gn.Node().id]; ok { - return +func (graph *Graph) isNecessary(n INode) bool { + nn := n.Node() + if _, isObserver := n.(IObserver); isObserver { + return true } - graph.observed[gn.Node().id] = gn - return + return len(nn.children) > 0 || len(nn.observers) > 0 } -func (graph *Graph) unobserveNodes(ctx context.Context, gn INode, observers ...IObserver) { - graph.unobserveSingleNode(ctx, gn, observers...) +func (graph *Graph) initializeNode(gn INode) { gnn := gn.Node() - parents := gnn.Parents() - for _, p := range parents { - graph.unobserveNodes(ctx, p, observers...) + gnn.graph = graph + graphAlreadyHasNode := graph.maybeAddNodeToGraph(gn) + if graphAlreadyHasNode { + return } + graph.numNodes++ + gnn.detectCutoff(gn) + gnn.detectAlways(gn) + gnn.detectStabilize(gn) } -func (graph *Graph) unobserveSingleNode(ctx context.Context, gn INode, observers ...IObserver) { - remainingObserverCount := graph.removeNodeObservers(gn, observers...) - if remainingObserverCount > 0 { +func (graph *Graph) maybeAddNodeToGraph(gn INode) (ok bool) { + graph.nodesMu.Lock() + defer graph.nodesMu.Unlock() + if _, ok = graph.nodes[gn.Node().id]; ok { return } - if typed, ok := gn.(IUnobserve); ok { - typed.Unobserve(ctx, observers...) - } - graph.removeNodeFromGraph(gn) + graph.nodes[gn.Node().id] = gn + return } func (graph *Graph) removeNodeFromGraph(gn INode) { graph.recomputeHeap.remove(gn) graph.adjustHeightsHeap.remove(gn) - graph.observedMu.Lock() - delete(graph.observed, gn.Node().id) - graph.observedMu.Unlock() + graph.nodesMu.Lock() + delete(graph.nodes, gn.Node().id) + graph.nodesMu.Unlock() graph.numNodes-- @@ -296,64 +305,32 @@ func (graph *Graph) removeNodeFromGraph(gn INode) { gnn.setAt = 0 gnn.boundAt = 0 gnn.recomputedAt = 0 - gnn.createdIn = nil + + // NOTE (wc): we never _really_ can remove the createdIn reference because + // we don't track construction of nodes carefully. + // gnn.createdIn = nil gnn.graph = nil gnn.height = 0 gnn.heightInRecomputeHeap = 0 gnn.heightInAdjustHeightsHeap = 0 } -func (graph *Graph) removeNodeObservers(gn INode, observers ...IObserver) (remainingObserverCount int) { - gnn := gn.Node() - for _, on := range observers { - if graph.canReachObserver(gn, on.Node().id) { - continue - } - gnn.removeObserver(on.Node().id) - for _, handler := range gnn.onUnobservedHandlers { - handler(on) - } - } - remainingObserverCount = len(gnn.observers) - return -} - -func (graph *Graph) canReachObserver(gn INode, oid Identifier) bool { - return graph.canReachObserverRecursive(gn, gn, oid) -} - -func (graph *Graph) canReachObserverRecursive(root, gn INode, oid Identifier) bool { - for _, c := range gn.Node().children { - // if any of our children still have the observer in their observed lists - // return true immediately - if _, ok := c.Node().observerLookup[oid]; ok { - return true - } - if c.Node().id == oid { - return true - } - if graph.canReachObserverRecursive(root, c, oid) { - return true - } - } - return false -} - -func (graph *Graph) addObserver(on IObserver) { +func (graph *Graph) addObserver(on IObserver) error { onn := on.Node() onn.graph = graph graph.observersMu.Lock() if _, ok := graph.observers[onn.id]; !ok { graph.numNodes++ + graph.observers[onn.id] = on } - graph.observers[onn.id] = on graph.observersMu.Unlock() - onn.detectStabilize(on) - for _, p := range onn.parents { - _ = graph.adjustHeightsHeap.ensureHeightRequirement(on, p) + + if err := graph.adjustHeights(on); err != nil { + return err } + return nil } func (graph *Graph) removeObserver(on IObserver) { @@ -369,6 +346,31 @@ func (graph *Graph) removeObserver(on IObserver) { graph.observersMu.Lock() delete(graph.observers, onn.id) graph.observersMu.Unlock() + + graph.recomputeHeap.remove(on) + graph.adjustHeightsHeap.remove(on) + + onn.height = 0 + onn.heightInRecomputeHeap = 0 + onn.heightInAdjustHeightsHeap = 0 + onn.setAt = 0 + onn.changedAt = 0 + onn.recomputedAt = 0 +} + +func (graph *Graph) adjustHeights(node INode) error { + _ = graph.adjustHeightsHeap.setHeight(node, heightFromScope(node)+1) + for _, p := range node.Node().parents { + if p.Node().height >= node.Node().height { + _ = graph.adjustHeightsHeap.setHeight(node, p.Node().height+1) + } + } + for _, parent := range node.Node().parents { + if err := graph.adjustHeightsHeap.adjustHeights(graph.recomputeHeap, node, parent); err != nil { + return err + } + } + return nil } // @@ -448,6 +450,7 @@ func (graph *Graph) recompute(ctx context.Context, n INode) (err error) { graph.numNodesRecomputed++ nn := n.Node() nn.numRecomputes++ + nn.recomputedAt = graph.stabilizationNum var shouldCutoff bool @@ -467,8 +470,6 @@ func (graph *Graph) recompute(ctx context.Context, n INode) (err error) { graph.numNodesChanged++ nn.numChanges++ - // we have to propagate the "changed" or "recomputed" status to children - nn.changedAt = graph.stabilizationNum if err = nn.maybeStabilize(ctx); err != nil { for _, eh := range nn.onErrorHandlers { eh(ctx, err) @@ -476,6 +477,9 @@ func (graph *Graph) recompute(ctx context.Context, n INode) (err error) { return } + // we have to propagate the "changed" status to children + nn.changedAt = graph.stabilizationNum + if len(nn.onUpdateHandlers) > 0 { graph.handleAfterStabilization[nn.id] = append(graph.handleAfterStabilization[nn.id], nn.onUpdateHandlers...) } @@ -495,16 +499,3 @@ func (graph *Graph) recompute(ctx context.Context, n INode) (err error) { } return } - -func (graph *Graph) isNecessary(n INode) bool { - ng := n.Node().graph - return ng != nil && ng.id == graph.id -} - -// -// internal height management methods -// - -func (graph *Graph) recomputeHeights() error { - return graph.adjustHeightsHeap.adjustHeights(graph.recomputeHeap) -} diff --git a/graph_test.go b/graph_test.go index 106b1ad..5440166 100644 --- a/graph_test.go +++ b/graph_test.go @@ -14,12 +14,9 @@ func Test_New(t *testing.T) { m0 := Map2(g, r0, r1, func(v0, v1 string) string { return v0 + v1 }) _ = Observe(g, m0) - testutil.Equal(t, true, g.IsObserving(r0)) - testutil.Equal(t, true, g.IsObserving(r1)) - testutil.Equal(t, true, g.IsObserving(m0)) - - m1 := Map2(g, r0, r1, func(v0, v1 string) string { return v0 + v1 }) - testutil.Equal(t, false, g.IsObserving(m1)) + testutil.Equal(t, true, g.Has(r0)) + testutil.Equal(t, true, g.Has(r1)) + testutil.Equal(t, true, g.Has(m0)) } func Test_New_options(t *testing.T) { @@ -42,86 +39,6 @@ func Test_Graph_Label(t *testing.T) { testutil.Equal(t, "hello", g.Label()) } -func Test_Graph_UnobserveNodes(t *testing.T) { - ctx := testContext() - g := New() - - r0 := Return(g, "hello") - m0 := Map(g, r0, ident) - m1 := Map(g, m0, ident) - m2 := Map(g, m1, ident) - - ar0 := Return(g, "hello") - am0 := Map(g, ar0, ident) - am1 := Map(g, am0, ident) - am2 := Map(g, am1, ident) - - o1 := Observe(g, m1) - _ = Observe(g, am2) - - testutil.Equal(t, true, g.IsObserving(r0)) - testutil.Equal(t, true, g.IsObserving(m0)) - testutil.Equal(t, true, g.IsObserving(m1)) - testutil.Equal(t, false, g.IsObserving(m2), "using the Observe incremental we actually don't care about m2!") - - testutil.Equal(t, true, g.IsObserving(ar0)) - testutil.Equal(t, true, g.IsObserving(am0)) - testutil.Equal(t, true, g.IsObserving(am1)) - testutil.Equal(t, true, g.IsObserving(am2)) - - Unlink(o1, m1) - g.unobserveNodes(ctx, m1, o1) - - testutil.Equal(t, false, g.IsObserving(r0)) - testutil.Equal(t, false, g.IsObserving(m0)) - testutil.Equal(t, false, g.IsObserving(m1)) - testutil.Equal(t, false, g.IsObserving(m2)) - - testutil.Nil(t, r0.Node().graph) - testutil.Nil(t, m0.Node().graph) - testutil.Nil(t, m1.Node().graph) - testutil.Nil(t, m2.Node().graph) - - testutil.Equal(t, true, g.IsObserving(ar0)) - testutil.Equal(t, true, g.IsObserving(am0)) - testutil.Equal(t, true, g.IsObserving(am1)) - testutil.Equal(t, true, g.IsObserving(am2)) -} - -func Test_Graph_UnobserveNodes_notObserving(t *testing.T) { - ctx := testContext() - g := New() - - r0 := Return(g, "hello") - m0 := Map(g, r0, ident) - m1 := Map(g, m0, ident) - m2 := Map(g, m1, ident) - - ar0 := Return(g, "hello") - am0 := Map(g, ar0, ident) - am1 := Map(g, am0, ident) - am2 := Map(g, am1, ident) - - o := Observe(g, m1) - - testutil.Equal(t, true, g.IsObserving(r0)) - testutil.Equal(t, true, g.IsObserving(m0)) - testutil.Equal(t, true, g.IsObserving(m1)) - testutil.Equal(t, false, g.IsObserving(m2), "we observed m1, which is the parent of m2!") - - testutil.Equal(t, false, g.IsObserving(ar0)) - testutil.Equal(t, false, g.IsObserving(am0)) - testutil.Equal(t, false, g.IsObserving(am1)) - testutil.Equal(t, false, g.IsObserving(am2)) - - g.unobserveNodes(ctx, am1, o) - - testutil.Equal(t, true, g.IsObserving(r0)) - testutil.Equal(t, true, g.IsObserving(m0)) - testutil.Equal(t, true, g.IsObserving(m1)) - testutil.Equal(t, false, g.IsObserving(m2)) -} - func Test_Graph_IsStabilizing(t *testing.T) { g := New() testutil.Equal(t, false, g.IsStabilizing()) @@ -144,7 +61,7 @@ func Test_Graph_addObserver_rediscover(t *testing.T) { g.recomputeHeap.remove(o) testutil.Equal(t, false, g.recomputeHeap.has(o)) - g.addObserver(o) + _ = g.addObserver(o) testutil.Equal(t, 2, g.numNodes) testutil.Equal(t, 1, o.Node().height) testutil.Equal(t, false, g.recomputeHeap.has(o)) @@ -152,7 +69,7 @@ func Test_Graph_addObserver_rediscover(t *testing.T) { func Test_Graph_recompute_recomputesObservers(t *testing.T) { g := New() - n := newMockBareNode() + n := newMockBareNode(g) o := Observe(g, n) g.recomputeHeap.Clear() @@ -171,7 +88,7 @@ func Test_Graph_removeNodeFromGraph(t *testing.T) { mn00 := newMockBareNodeWithHeight(2) g.numNodes = 2 - g.observed[mn00.n.id] = mn00 + g.nodes[mn00.n.id] = mn00 g.handleAfterStabilization[mn00.n.id] = []func(context.Context){ func(_ context.Context) {}, diff --git a/identifier.go b/identifier.go index d1e52fc..ce035c7 100644 --- a/identifier.go +++ b/identifier.go @@ -21,6 +21,17 @@ func NewIdentifier() (output Identifier) { return } +// MustParseIdentifier is the reverse of `.String()` that will +// panic if an error is returned by `ParseIdentifier`. +func MustParseIdentifier(raw string) (output Identifier) { + var err error + output, err = ParseIdentifier(raw) + if err != nil { + panic(err) + } + return +} + // ParseIdentifier is the reverse of `.String()`. func ParseIdentifier(raw string) (output Identifier, err error) { if raw == "" { diff --git a/incrutil/diff_maps.go b/incrutil/diff_maps.go index 9f01db6..8498f1b 100644 --- a/incrutil/diff_maps.go +++ b/incrutil/diff_maps.go @@ -11,18 +11,16 @@ import ( // for keys removed, and each stabilization pass returns just the subset // of the map that changed since the last pass according to the keys. func DiffMapByKeys[K comparable, V any](scope incr.Scope, i incr.Incr[map[K]V]) (add incr.Incr[map[K]V], rem incr.Incr[map[K]V]) { - add = &diffMapByKeysAddedIncr[K, V]{ + add = incr.WithinScope(scope, &diffMapByKeysAddedIncr[K, V]{ n: incr.NewNode("diff_maps_by_keys_added"), i: i, - } + }) incr.Link(add, i) - add = incr.WithinScope(scope, add) - rem = &diffMapByKeysRemovedIncr[K, V]{ + rem = incr.WithinScope(scope, &diffMapByKeysRemovedIncr[K, V]{ n: incr.NewNode("diff_maps_by_keys_removed"), i: i, - } + }) incr.Link(rem, i) - rem = incr.WithinScope(scope, rem) return } @@ -30,24 +28,24 @@ func DiffMapByKeys[K comparable, V any](scope incr.Scope, i incr.Incr[map[K]V]) // incremental, and each stabilization pass returns just the subset // of the map that was added since the last pass according to the keys. func DiffMapByKeysAdded[K comparable, V any](scope incr.Scope, i incr.Incr[map[K]V]) incr.Incr[map[K]V] { - o := &diffMapByKeysAddedIncr[K, V]{ + o := incr.WithinScope(scope, &diffMapByKeysAddedIncr[K, V]{ n: incr.NewNode("diff_maps_by_keys_added"), i: i, - } + }) incr.Link(o, i) - return incr.WithinScope(scope, o) + return o } // DiffMapByKeysRemoved returns an incremental that takes an input map typed // incremental, and each stabilization pass returns just the subset // of the map that was removed since the last pass according to the keys. func DiffMapByKeysRemoved[K comparable, V any](scope incr.Scope, i incr.Incr[map[K]V]) incr.Incr[map[K]V] { - o := &diffMapByKeysRemovedIncr[K, V]{ + o := incr.WithinScope(scope, &diffMapByKeysRemovedIncr[K, V]{ n: incr.NewNode("diff_maps_by_keys_removed"), i: i, - } + }) incr.Link(o, i) - return incr.WithinScope(scope, o) + return o } var ( diff --git a/incrutil/diff_slice.go b/incrutil/diff_slice.go index b8733fe..d2bccc9 100644 --- a/incrutil/diff_slice.go +++ b/incrutil/diff_slice.go @@ -9,12 +9,12 @@ import ( // DiffSliceByIndicesAdded diffs a slice between stabilizations, yielding an // incremental that is just the added elements per pass. func DiffSliceByIndicesAdded[T any](scope incr.Scope, i incr.Incr[[]T]) incr.Incr[[]T] { - o := &diffSliceByIndicesAddedIncr[T]{ + o := incr.WithinScope(scope, &diffSliceByIndicesAddedIncr[T]{ n: incr.NewNode("diff_slice_by_indices_added"), i: i, - } + }) incr.Link(o, i) - return incr.WithinScope(scope, o) + return o } type diffSliceByIndicesAddedIncr[T any] struct { diff --git a/link.go b/link.go index f39ab98..0306a86 100644 --- a/link.go +++ b/link.go @@ -8,8 +8,24 @@ package incr // An error is returned if the provided inputs to the child node // would produce a cycle. func Link(child INode, parents ...INode) { + graph := graphFromScope(child) + wasNecessary := graph.isNecessary(child) child.Node().addParents(parents...) for _, parent := range parents { parent.Node().addChildren(child) } + if !wasNecessary { + graph.becameNecessary(child) + } + for _, parent := range parents { + _ = graph.adjustHeightsHeap.adjustHeights(graph.recomputeHeap, child, parent) + } +} + +func graphFromScope(n INode) *Graph { + return n.Node().createdIn.scopeGraph() +} + +func heightFromScope(n INode) int { + return n.Node().createdIn.scopeHeight() } diff --git a/main_test.go b/main_test.go index 65ade2e..6e79cee 100644 --- a/main_test.go +++ b/main_test.go @@ -89,10 +89,10 @@ func identMany[T any](v ...T) (out T) { var _ Incr[any] = (*mockBareNode)(nil) -func mockObserver() IObserver { - return &observeIncr[any]{ +func mockObserver(scope Scope) IObserver { + return WithinScope(scope, &observeIncr[any]{ n: NewNode("mock_observer"), - } + }) } func newMockBareNodeWithHeight(height int) *mockBareNode { @@ -103,10 +103,10 @@ func newMockBareNodeWithHeight(height int) *mockBareNode { return mbn } -func newMockBareNode() *mockBareNode { - return &mockBareNode{ +func newMockBareNode(scope Scope) *mockBareNode { + return WithinScope(scope, &mockBareNode{ n: NewNode("bare_node"), - } + }) } type mockBareNode struct { @@ -121,23 +121,23 @@ func (mn *mockBareNode) Value() any { return nil } -func newHeightIncr(height int) *heightIncr { - return &heightIncr{ +func newHeightIncr(scope Scope, height int) *heightIncr { + return WithinScope(scope, &heightIncr{ n: &Node{ id: NewIdentifier(), height: height, }, - } + }) } -func newHeightIncrLabel(height int, label string) *heightIncr { - return &heightIncr{ +func newHeightIncrLabel(scope Scope, height int, label string) *heightIncr { + return WithinScope(scope, &heightIncr{ n: &Node{ id: NewIdentifier(), height: height, label: label, }, - } + }) } type heightIncr struct { diff --git a/map.go b/map.go index a653e9d..27cb63e 100644 --- a/map.go +++ b/map.go @@ -17,13 +17,13 @@ func Map[A, B any](scope Scope, a Incr[A], fn func(A) B) Incr[B] { // a new incremental of the output type of that function but is context aware // and can also return an error, aborting stabilization. func MapContext[A, B any](scope Scope, a Incr[A], fn func(context.Context, A) (B, error)) Incr[B] { - m := &mapIncr[A, B]{ + m := WithinScope(scope, &mapIncr[A, B]{ n: NewNode("map"), a: a, fn: fn, - } + }) Link(m, a) - return WithinScope(scope, m) + return m } var ( diff --git a/map2.go b/map2.go index 2ee831c..8b5783b 100644 --- a/map2.go +++ b/map2.go @@ -17,14 +17,14 @@ func Map2[A, B, C any](scope Scope, a Incr[A], b Incr[B], fn func(A, B) C) Incr[ // to a given input incremental and returns a new incremental of // the output type of that function. func Map2Context[A, B, C any](scope Scope, a Incr[A], b Incr[B], fn func(context.Context, A, B) (C, error)) Incr[C] { - o := &map2Incr[A, B, C]{ + o := WithinScope(scope, &map2Incr[A, B, C]{ n: NewNode("map2"), a: a, b: b, fn: fn, - } + }) Link(o, a, b) - return WithinScope(scope, o) + return o } var ( diff --git a/map3.go b/map3.go index 7831df6..9ba8ca6 100644 --- a/map3.go +++ b/map3.go @@ -17,15 +17,15 @@ func Map3[A, B, C, D any](scope Scope, a Incr[A], b Incr[B], c Incr[C], fn func( // an error, to given input incrementals and returns a // new incremental of the output type of that function. func Map3Context[A, B, C, D any](scope Scope, a Incr[A], b Incr[B], c Incr[C], fn func(context.Context, A, B, C) (D, error)) Incr[D] { - o := &map3Incr[A, B, C, D]{ + o := WithinScope(scope, &map3Incr[A, B, C, D]{ n: NewNode("map3"), a: a, b: b, c: c, fn: fn, - } + }) Link(o, a, b, c) - return WithinScope(scope, o) + return o } var ( diff --git a/map4.go b/map4.go index 880bc15..28a6aac 100644 --- a/map4.go +++ b/map4.go @@ -17,16 +17,16 @@ func Map4[A, B, C, D, E any](scope Scope, a Incr[A], b Incr[B], c Incr[C], d Inc // an error, to given input incrementals and returns a // new incremental of the output type of that function. func Map4Context[A, B, C, D, E any](scope Scope, a Incr[A], b Incr[B], c Incr[C], d Incr[D], fn func(context.Context, A, B, C, D) (E, error)) Incr[E] { - o := &map4Incr[A, B, C, D, E]{ + o := WithinScope(scope, &map4Incr[A, B, C, D, E]{ n: NewNode("map4"), a: a, b: b, c: c, d: d, fn: fn, - } + }) Link(o, a, b, c, d) - return WithinScope(scope, o) + return o } var ( diff --git a/map_if.go b/map_if.go index 2b55758..f1057dd 100644 --- a/map_if.go +++ b/map_if.go @@ -11,14 +11,14 @@ import ( // Specifically, we term this _Apply_If because the nodes are all // linked in the graph, but the value changes during stabilization. func MapIf[A any](scope Scope, a, b Incr[A], p Incr[bool]) Incr[A] { - o := &mapIfIncr[A]{ + o := WithinScope(scope, &mapIfIncr[A]{ n: NewNode("map_if"), a: a, b: b, p: p, - } + }) Link(o, a, b, p) - return WithinScope(scope, o) + return o } var ( diff --git a/map_n.go b/map_n.go index 5ed7f4b..c2ca4ed 100644 --- a/map_n.go +++ b/map_n.go @@ -16,15 +16,15 @@ func MapN[A, B any](scope Scope, fn MapNFunc[A, B], inputs ...Incr[A]) MapNIncr[ // MapNContext applies a function to given list of input incrementals and returns // a new incremental of the output type of that function. func MapNContext[A, B any](scope Scope, fn MapNContextFunc[A, B], inputs ...Incr[A]) MapNIncr[A, B] { - o := &mapNIncr[A, B]{ + o := WithinScope(scope, &mapNIncr[A, B]{ n: NewNode("map_n"), inputs: inputs, fn: fn, - } + }) for _, i := range inputs { Link(o, i) } - return WithinScope(scope, o) + return o } // MapNFunc is the function that the MapN incremental applies. diff --git a/node.go b/node.go index 1e4af55..57d0f02 100644 --- a/node.go +++ b/node.go @@ -76,12 +76,6 @@ type Node struct { // onErrorHandlers are functions that are called when the node updates. // they are added with `OnUpdate(...)`. onErrorHandlers []func(context.Context, error) - // onObservedHandlers are functions that are called when the node is observed. - // they are added with `OnObserved(...)`. - onObservedHandlers []func(IObserver) - // onUnobservedHandlers are functions that are called when the node is unobserved. - // they are added with `OnUnobserved(...)`. - onUnobservedHandlers []func(IObserver) // stabilize is set during initialization and is a shortcut // to the interface sniff for the node for the IStabilize interface. stabilize func(context.Context) error @@ -148,16 +142,6 @@ func (n *Node) OnError(fn func(context.Context, error)) { n.onErrorHandlers = append(n.onErrorHandlers, fn) } -// OnObserved registers an observed handler. -func (n *Node) OnObserved(fn func(IObserver)) { - n.onObservedHandlers = append(n.onObservedHandlers, fn) -} - -// OnUnobserved registers an unobserved handler. -func (n *Node) OnUnobserved(fn func(IObserver)) { - n.onUnobservedHandlers = append(n.onUnobservedHandlers, fn) -} - // Label returns a descriptive label for the node or // an empty string if one hasn't been provided. func (n *Node) Label() string { @@ -241,9 +225,6 @@ func (n *Node) addObservers(observers ...IObserver) { if !n.observerLookup.has(o.Node().id) { n.observers = append(n.observers, o) n.observerLookup.add(o.Node().id) - for _, handler := range n.onObservedHandlers { - handler(o) - } } } } diff --git a/node_list.go b/node_list.go index f111abb..c5ac573 100644 --- a/node_list.go +++ b/node_list.go @@ -18,3 +18,8 @@ func hasKey[A INode](nodes []A, id Identifier) bool { } return false } + +func mapHasKey[K comparable, V any](m map[K]V, k K) (ok bool) { + _, ok = m[k] + return +} diff --git a/node_list_test.go b/node_list_test.go index 9bd9524..5253afd 100644 --- a/node_list_test.go +++ b/node_list_test.go @@ -7,9 +7,11 @@ import ( ) func Test_remove(t *testing.T) { - n0 := newMockBareNode() - n1 := newMockBareNode() - n2 := newMockBareNode() + g := New() + + n0 := newMockBareNode(g) + n1 := newMockBareNode(g) + n2 := newMockBareNode(g) nodes := []INode{ n0, n1, n2, } diff --git a/node_test.go b/node_test.go index 9ee84da..6f2cb4b 100644 --- a/node_test.go +++ b/node_test.go @@ -53,10 +53,12 @@ func Test_Node_Metadata(t *testing.T) { } func Test_Link(t *testing.T) { - c := newMockBareNode() - p0 := newMockBareNode() - p1 := newMockBareNode() - p2 := newMockBareNode() + g := New() + + c := newMockBareNode(g) + p0 := newMockBareNode(g) + p1 := newMockBareNode(g) + p2 := newMockBareNode(g) // set up P with (3) inputs Link(c, p0, p1, p2) @@ -83,7 +85,9 @@ func Test_Link(t *testing.T) { } func Test_Node_String(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) n.n.height = 2 testutil.Equal(t, "bare_node["+n.n.id.Short()+"]@2", n.Node().String()) @@ -94,7 +98,7 @@ func Test_Node_String(t *testing.T) { func Test_SetStale(t *testing.T) { g := New() - n := newMockBareNode() + n := newMockBareNode(g) _ = Observe(g, n) g.SetStale(n) @@ -143,13 +147,15 @@ func Test_Node_SetLabel(t *testing.T) { } func Test_Node_addChildren(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) _ = n.Node() - c0 := newMockBareNode() + c0 := newMockBareNode(g) _ = c0.Node() - c1 := newMockBareNode() + c1 := newMockBareNode(g) _ = c1.Node() testutil.Equal(t, 0, len(n.n.parents)) @@ -165,16 +171,18 @@ func Test_Node_addChildren(t *testing.T) { } func Test_Node_removeChild(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) _ = n.Node() - c0 := newMockBareNode() + c0 := newMockBareNode(g) _ = c0.Node() - c1 := newMockBareNode() + c1 := newMockBareNode(g) _ = c1.Node() - c2 := newMockBareNode() + c2 := newMockBareNode(g) _ = c2.Node() n.Node().addChildren(c0, c1, c2) @@ -196,13 +204,15 @@ func Test_Node_removeChild(t *testing.T) { } func Test_Node_addParents(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) _ = n.Node() - c0 := newMockBareNode() + c0 := newMockBareNode(g) _ = c0.Node() - c1 := newMockBareNode() + c1 := newMockBareNode(g) _ = c1.Node() testutil.Equal(t, 0, len(n.n.parents)) @@ -218,16 +228,18 @@ func Test_Node_addParents(t *testing.T) { } func Test_Node_removeParent(t *testing.T) { - n := newMockBareNode() + g := New() + + n := newMockBareNode(g) _ = n.Node() - c0 := newMockBareNode() + c0 := newMockBareNode(g) _ = c0.Node() - c1 := newMockBareNode() + c1 := newMockBareNode(g) _ = c1.Node() - c2 := newMockBareNode() + c2 := newMockBareNode(g) _ = c2.Node() n.Node().addParents(c0, c1, c2) @@ -332,6 +344,8 @@ func Test_Node_detectStabilize(t *testing.T) { } func Test_Node_shouldRecompute(t *testing.T) { + g := New() + n := NewNode("test_node") testutil.Equal(t, true, n.ShouldRecompute()) @@ -347,10 +361,10 @@ func Test_Node_shouldRecompute(t *testing.T) { testutil.Equal(t, true, n.ShouldRecompute()) n.changedAt = 1 - c1 := newMockBareNode() + c1 := newMockBareNode(g) c1.Node().changedAt = 2 - n.addParents(newMockBareNode(), c1) + n.addParents(newMockBareNode(g), c1) testutil.Equal(t, true, n.ShouldRecompute()) c1.Node().changedAt = 1 @@ -368,7 +382,7 @@ func Test_Node_recompute(t *testing.T) { return "hello", nil }) - p := newMockBareNode() + p := newMockBareNode(g) m0.Node().addParents(p) _ = Observe(g, m0) @@ -428,7 +442,7 @@ func Test_Node_stabilize_error(t *testing.T) { return "", fmt.Errorf("test error") }) - p := newMockBareNode() + p := newMockBareNode(g) m0.Node().addParents(p) _ = Observe(g, m0) @@ -505,18 +519,20 @@ func Test_nodeFormatters(t *testing.T) { } func Test_Node_Properties_readonly(t *testing.T) { + g := New() + n := &Node{ height: 1, setAt: 2, changedAt: 3, children: []INode{ - newMockBareNode(), - newMockBareNode(), + newMockBareNode(g), + newMockBareNode(g), }, parents: []INode{ - newMockBareNode(), - newMockBareNode(), - newMockBareNode(), + newMockBareNode(g), + newMockBareNode(g), + newMockBareNode(g), }, } @@ -553,23 +569,25 @@ func Test_Node_Observers(t *testing.T) { n := &Node{ observers: []IObserver{one, two}, } - testutil.Equal(t, 2, len(n.Observers())) + testutil.Equal(t, 2, len(n.observers)) } func Test_nodeSorter(t *testing.T) { - a := newMockBareNode() + g := New() + + a := newMockBareNode(g) a.Node().height = 1 a.Node().id, _ = ParseIdentifier(strings.Repeat("0", 32)) - b := newMockBareNode() + b := newMockBareNode(g) b.Node().height = 1 b.Node().id, _ = ParseIdentifier(strings.Repeat("1", 32)) - c := newMockBareNode() + c := newMockBareNode(g) c.Node().height = 1 c.Node().id, _ = ParseIdentifier(strings.Repeat("2", 32)) - d := newMockBareNode() + d := newMockBareNode(g) d.Node().height = 2 d.Node().id, _ = ParseIdentifier(strings.Repeat("3", 32)) diff --git a/observe.go b/observe.go index 818cb5e..06dd0fd 100644 --- a/observe.go +++ b/observe.go @@ -8,22 +8,22 @@ import ( // Observe observes a node, specifically including it for computation // as well as all of its parents. func Observe[A any](g *Graph, input Incr[A]) ObserveIncr[A] { - o := &observeIncr[A]{ + o := WithinScope(g, &observeIncr[A]{ n: NewNode("observer"), input: input, - } + }) Link(o, input) - g.addObserver(o) - // NOTE(wc): we do this here because some """expert""" use cases for `ExpertGraph::DiscoverObserver` + // NOTE(wc): we do this here because some """expert""" use cases for `ExpertGraph::AddObserver` // require us to add the observer to the graph observer list but _not_ // add it to the recompute heap. // // So we just add it here explicitly and don't add it implicitly - // in the DiscoverObserver function. + // in the AddObserver function. + _ = g.addObserver(o) + input.Node().addObservers(o) + g.becameNecessary(input) g.recomputeHeap.add(o) - g.observeNodes(input, o) - _ = g.recomputeHeights() - return WithinScope(g, o) + return o } // ObserveIncr is an incremental that observes a graph @@ -75,16 +75,9 @@ func (o *observeIncr[A]) Stabilize(_ context.Context) error { func (o *observeIncr[A]) Unobserve(ctx context.Context) { g := o.n.graph - for _, p := range o.n.parents { - Unlink(o, p) - } - - g.unobserveNodes(ctx, o.input, o) + g.removeParents(o) g.removeObserver(o) - o.n.children = nil - o.n.parents = nil - // zero out the observed value var value A o.value = value diff --git a/observe_test.go b/observe_test.go index 56f9021..b8a407a 100644 --- a/observe_test.go +++ b/observe_test.go @@ -1,139 +1 @@ package incr - -import ( - "context" - "testing" - - "github.com/wcharczuk/go-incr/testutil" -) - -func Test_Observe_Unobserve(t *testing.T) { - ctx := testContext() - g := New() - - v0 := Var(g, "hello 0") - m0 := Map(g, v0, ident) - - v1 := Var(g, "hello 1") - m1 := Map(g, v1, ident) - - o0 := Observe(g, m0) - o1 := Observe(g, m1) - - testutil.Equal(t, 6, g.numNodes) - - testutil.Equal(t, true, g.IsObserving(m0)) - testutil.Equal(t, true, g.IsObserving(m1)) - - testutil.Equal(t, "", o0.Value()) - testutil.Equal(t, "", o1.Value()) - - err := g.Stabilize(context.TODO()) - testutil.Nil(t, err) - - testutil.Equal(t, "hello 0", o0.Value()) - testutil.Equal(t, "hello 1", o1.Value()) - - o1.Unobserve(ctx) - - testutil.Equal(t, len(g.observed), g.numNodes-1, "we don't observe the observer but we do track it!") - testutil.Nil(t, o1.Node().graph) - - // should take effect immediately because there is only (1) observer. - testutil.Equal(t, true, g.IsObserving(m0)) - testutil.Equal(t, false, g.IsObserving(m1)) - - v0.Set("not hello 0") - v1.Set("not hello 1") - err = g.Stabilize(context.TODO()) - testutil.Nil(t, err) - - testutil.Equal(t, "not hello 0", o0.Value()) - testutil.Equal(t, "", o1.Value()) -} - -func Test_Observe_Unobserve_multiple(t *testing.T) { - ctx := testContext() - g := New() - - v0 := Var(g, "hello 0") - m0 := Map(g, v0, ident) - - v1 := Var(g, "hello 1") - m1 := Map(g, v1, ident) - - o0 := Observe(g, m0) - o1 := Observe(g, m1) - o11 := Observe(g, m1) - - testutil.Equal(t, true, g.IsObserving(v0)) - testutil.Equal(t, true, g.IsObserving(m0)) - testutil.Equal(t, true, g.IsObserving(v1)) - testutil.Equal(t, true, g.IsObserving(m1)) - - testutil.Equal(t, 1, len(v0.Node().Observers())) - testutil.Equal(t, 1, len(m0.Node().Observers())) - testutil.Equal(t, 2, len(v1.Node().Observers())) - testutil.Equal(t, 2, len(m1.Node().Observers())) - - testutil.Equal(t, "", o0.Value()) - testutil.Equal(t, "", o1.Value()) - testutil.Equal(t, "", o11.Value()) - - err := g.Stabilize(context.TODO()) - testutil.Nil(t, err) - - testutil.Equal(t, "hello 0", o0.Value()) - testutil.Equal(t, "hello 1", o1.Value()) - testutil.Equal(t, "hello 1", o11.Value()) - - o1.Unobserve(ctx) - - testutil.Equal(t, len(g.observed), g.numNodes-2, "we should have (1) less observer after unobserve!") - testutil.Nil(t, o1.Node().graph) - - testutil.Equal(t, 0, len(o1.Node().parents)) - testutil.Equal(t, 0, len(o1.Node().children)) - testutil.None(t, m1.Node().Children(), func(n INode) bool { - return n.Node().ID() == o1.Node().ID() - }) - - testutil.Equal(t, true, g.IsObserving(m0)) - testutil.Equal(t, true, g.IsObserving(m1)) - - testutil.Equal(t, 1, len(v0.Node().Observers())) - testutil.Equal(t, 1, len(m0.Node().Observers())) - testutil.Equal(t, 1, len(v1.Node().Observers())) - testutil.Equal(t, 1, len(m1.Node().Observers())) - - v0.Set("not hello 0") - v1.Set("not hello 1") - err = g.Stabilize(ctx) - testutil.Nil(t, err) - - testutil.Equal(t, "not hello 0", o0.Value()) - testutil.Equal(t, "", o1.Value()) - testutil.Equal(t, "not hello 1", o11.Value()) -} - -func Test_Observer_Unobserve_reobserve(t *testing.T) { - ctx := testContext() - g := New() - v0 := Var(g, "hello") - m0 := Map(g, v0, ident) - o0 := Observe(g, m0) - - _ = g.Stabilize(context.TODO()) - testutil.Equal(t, "hello", o0.Value()) - - o0.Unobserve(ctx) - - _ = g.Stabilize(context.TODO()) - testutil.Equal(t, false, g.IsObserving(m0)) - // strictly, the value shouldn't change ... - testutil.Equal(t, "hello", m0.Value()) - - o1 := Observe(g, m0) - _ = g.Stabilize(context.TODO()) - testutil.Equal(t, "hello", o1.Value()) -} diff --git a/observer_overhaul_notes.md b/observer_overhaul_notes.md new file mode 100644 index 0000000..a1ce2e8 --- /dev/null +++ b/observer_overhaul_notes.md @@ -0,0 +1,27 @@ +Observer Overhaul Notes +======================= + +- The current implementation of observers is overly aggressive; we almost never really care about parts of a graph that are linked but unobserved (can this even happen?) +- We can, as a result, do away with the hierarchical concept of observers and just use them to anchor points in the graph. + +Spedific changes +- Do away with recursive observed steps. +- Observation simply changes the observer(s) list for a node. +- Necessary is a combination of having children and being in a scope. +- When we unobserve, we snap the parent list of the observer and remove it from the observed nodes observer list. +- This does _not_ affect the graph meaningfully. If we swap out a bind graph, we have to remove it from the graph during the scope update. + +A node is necessary if: +- it has children +OR +- it has observers + +What _is_ the purpose of observers then? +- just to anchor leaves of the graph + +What happens if a node becomes unnecessary? +- if we leave it with no parents, and it's unobserved, it's removed from the graph. +- this happens at _unlink_ time +- we then proceed up (or down?) the graph, removing all parents of that newly unnecessary node + +Tests will be a pain to change but that's fine. \ No newline at end of file diff --git a/parallel_stabilize.go b/parallel_stabilize.go index 33ad45b..ea275c9 100644 --- a/parallel_stabilize.go +++ b/parallel_stabilize.go @@ -2,8 +2,6 @@ package incr import ( "context" - "fmt" - "runtime" ) // ParallelStabilize stabilizes graphs in parallel as entered @@ -39,7 +37,7 @@ func (graph *Graph) parallelStabilize(ctx context.Context) (err error) { return } workerPool := new(parallelBatch) - workerPool.SetLimit(runtime.NumCPU()) + workerPool.SetLimit(-1) //runtime.NumCPU()) var immediateRecompute []INode var minHeightBlock []INode @@ -54,9 +52,6 @@ func (graph *Graph) parallelStabilize(ctx context.Context) (err error) { if err = workerPool.Wait(); err != nil { break } - if err = graph.recomputeHeights(); err != nil { - break - } } graph.recomputeHeap.add(immediateRecompute...) return @@ -64,11 +59,6 @@ func (graph *Graph) parallelStabilize(ctx context.Context) (err error) { func (graph *Graph) parallelRecomputeNode(ctx context.Context, n INode) func() error { return func() (err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic stabilizing %v: %+v", n, r) - } - }() err = graph.recompute(ctx, n) return } diff --git a/parallel_stabilize_test.go b/parallel_stabilize_test.go index 7373706..77257bc 100644 --- a/parallel_stabilize_test.go +++ b/parallel_stabilize_test.go @@ -272,18 +272,6 @@ func Test_ParallelStabilize_always_cutoff_error(t *testing.T) { testutil.Equal(t, 3, g.recomputeHeap.len()) } -func Test_ParallelStabilize_recoversPanics(t *testing.T) { - g := New() - - v0 := Var(g, "hello") - gonnaPanic := Map(g, v0, func(_ string) string { - panic("help!") - }) - _ = Observe(g, gonnaPanic) - err := g.ParallelStabilize(testContext()) - testutil.NotNil(t, err) -} - func Test_ParallelStabilize_printsErrors(t *testing.T) { ctx := context.Background() g := New() diff --git a/recompute_heap_test.go b/recompute_heap_test.go index 000a81a..406435f 100644 --- a/recompute_heap_test.go +++ b/recompute_heap_test.go @@ -7,11 +7,13 @@ import ( ) func Test_recomputeHeap_add(t *testing.T) { + g := New() + rh := newRecomputeHeap(32) - n50 := newHeightIncr(5) - n60 := newHeightIncr(6) - n70 := newHeightIncr(7) + n50 := newHeightIncr(g, 5) + n60 := newHeightIncr(g, 6) + n70 := newHeightIncr(g, 7) rh.add(n50) @@ -57,22 +59,24 @@ func Test_recomputeHeap_add(t *testing.T) { } func Test_recomputeHeap_removeMinHeight(t *testing.T) { + g := New() + rh := newRecomputeHeap(10) - n00 := newHeightIncr(0) - n01 := newHeightIncr(0) - n02 := newHeightIncr(0) + n00 := newHeightIncr(g, 0) + n01 := newHeightIncr(g, 0) + n02 := newHeightIncr(g, 0) - n10 := newHeightIncr(1) - n11 := newHeightIncr(1) - n12 := newHeightIncr(1) - n13 := newHeightIncr(1) + n10 := newHeightIncr(g, 1) + n11 := newHeightIncr(g, 1) + n12 := newHeightIncr(g, 1) + n13 := newHeightIncr(g, 1) - n50 := newHeightIncr(5) - n51 := newHeightIncr(5) - n52 := newHeightIncr(5) - n53 := newHeightIncr(5) - n54 := newHeightIncr(5) + n50 := newHeightIncr(g, 5) + n51 := newHeightIncr(g, 5) + n52 := newHeightIncr(g, 5) + n53 := newHeightIncr(g, 5) + n54 := newHeightIncr(g, 5) rh.add(n00) rh.add(n01) @@ -139,13 +143,14 @@ func Test_recomputeHeap_removeMinHeight(t *testing.T) { } func Test_recomputeHeap_remove(t *testing.T) { + g := New() rh := newRecomputeHeap(10) - n10 := newHeightIncr(1) - n11 := newHeightIncr(1) - n20 := newHeightIncr(2) - n21 := newHeightIncr(2) - n22 := newHeightIncr(2) - n30 := newHeightIncr(3) + n10 := newHeightIncr(g, 1) + n11 := newHeightIncr(g, 1) + n20 := newHeightIncr(g, 2) + n21 := newHeightIncr(g, 2) + n22 := newHeightIncr(g, 2) + n30 := newHeightIncr(g, 3) // this should just return rh.remove(n10) @@ -233,16 +238,17 @@ func Test_recomputeHeap_maybeAddNewHeights(t *testing.T) { } func Test_recomputeHeap_add_adjustsHeights(t *testing.T) { + g := New() rh := newRecomputeHeap(8) Equal(t, 8, len(rh.heights)) - v0 := newHeightIncr(32) + v0 := newHeightIncr(g, 32) rh.add(v0) Equal(t, 33, len(rh.heights)) Equal(t, 32, rh.minHeight) Equal(t, 32, rh.maxHeight) - v1 := newHeightIncr(64) + v1 := newHeightIncr(g, 64) rh.add(v1) Equal(t, 65, len(rh.heights)) Equal(t, 32, rh.minHeight) @@ -250,38 +256,40 @@ func Test_recomputeHeap_add_adjustsHeights(t *testing.T) { } func Test_recomputeHeap_add_regression2(t *testing.T) { + g := New() + // another real world use case! also insane! rh := newRecomputeHeap(256) - observer4945d288 := newHeightIncr(1) + observer4945d288 := newHeightIncr(g, 1) rh.add(observer4945d288) - observer87df48be := newHeightIncr(1) + observer87df48be := newHeightIncr(g, 1) rh.add(observer87df48be) - mapf2cb6e46 := newHeightIncr(0) + mapf2cb6e46 := newHeightIncr(g, 0) rh.add(mapf2cb6e46) - map26e9bfb2a := newHeightIncr(0) + map26e9bfb2a := newHeightIncr(g, 0) rh.add(map26e9bfb2a) map26e9bfb2a.n.height = 2 rh.add(map26e9bfb2a) - map2dfe7c676 := newHeightIncr(1) + map2dfe7c676 := newHeightIncr(g, 1) rh.add(map2dfe7c676) - map2aa9d55f9 := newHeightIncr(1) + map2aa9d55f9 := newHeightIncr(g, 1) rh.add(map2aa9d55f9) - observerbaad6dd3 := newHeightIncr(1) + observerbaad6dd3 := newHeightIncr(g, 1) rh.add(observerbaad6dd3) - map2aa3f9a14 := newHeightIncr(1) + map2aa3f9a14 := newHeightIncr(g, 1) rh.add(map2aa3f9a14) - observer6e9e8864 := newHeightIncr(1) + observer6e9e8864 := newHeightIncr(g, 1) rh.add(observer6e9e8864) - varb35bfa8a := newHeightIncr(1) + varb35bfa8a := newHeightIncr(g, 1) rh.add(varb35bfa8a) - var54b93408 := newHeightIncr(1) + var54b93408 := newHeightIncr(g, 1) rh.add(var54b93408) - alwaysc83986c6 := newHeightIncr(0) + alwaysc83986c6 := newHeightIncr(g, 0) rh.add(alwaysc83986c6) - cutoff9d454a57 := newHeightIncr(0) + cutoff9d454a57 := newHeightIncr(g, 0) rh.add(cutoff9d454a57) - varc0898518 := newHeightIncr(1) + varc0898518 := newHeightIncr(g, 1) rh.add(varc0898518) cutoff9d454a57.n.height = 4 rh.add(cutoff9d454a57) @@ -313,12 +321,14 @@ func Test_recomputeHeap_add_regression2(t *testing.T) { } func Test_recomputeHeap_fix(t *testing.T) { + g := New() + rh := newRecomputeHeap(8) - v0 := newHeightIncr(2) + v0 := newHeightIncr(g, 2) rh.add(v0) - v1 := newHeightIncr(3) + v1 := newHeightIncr(g, 3) rh.add(v1) - v2 := newHeightIncr(4) + v2 := newHeightIncr(g, 4) rh.add(v2) Equal(t, 2, rh.minHeight) @@ -420,11 +430,12 @@ func Test_recomputeHeap_sanityCheck_badItemHeight(t *testing.T) { } func Test_recomputeHeap_clear(t *testing.T) { + g := New() rh := newRecomputeHeap(32) - n50 := newHeightIncr(5) - n60 := newHeightIncr(6) - n70 := newHeightIncr(7) + n50 := newHeightIncr(g, 5) + n60 := newHeightIncr(g, 6) + n70 := newHeightIncr(g, 7) rh.add(n50) rh.add(n60) diff --git a/scope.go b/scope.go index 6cd724a..7996a33 100644 --- a/scope.go +++ b/scope.go @@ -1,5 +1,7 @@ package incr +import "fmt" + // WithinScope updates a node's createdIn scope to reflect a new // inner-most bind scope applied by a bind. // @@ -26,21 +28,29 @@ func WithinScope[A INode](scope Scope, node A) A { type Scope interface { isRootScope() bool scopeGraph() *Graph + scopeHeight() int + fmt.Stringer } // BindScope is the scope that nodes are created in. // // Its either nil or the most recent bind. type bindScope struct { + input INode bind INode rhsNodes []INode } func (bs *bindScope) isRootScope() bool { return false } func (bs *bindScope) scopeGraph() *Graph { return bs.bind.Node().graph } +func (bs *bindScope) scopeHeight() int { return bs.input.Node().height } + +func (bs *bindScope) String() string { + return fmt.Sprintf("{%v}", bs.bind) +} func maybeRemoveScopeNode(scope Scope, node INode) { - if scope != nil && scope.isRootScope() { + if scope == nil || scope != nil && scope.isRootScope() { return } if typed, ok := scope.(*bindScope); ok && typed != nil { diff --git a/set.go b/set.go index 53e7a2b..26e175a 100644 --- a/set.go +++ b/set.go @@ -1,5 +1,13 @@ package incr +func newSet[T comparable](values ...T) set[T] { + output := make(set[T], len(values)) + for _, v := range values { + output[v] = struct{}{} + } + return output +} + type set[T comparable] map[T]struct{} func (s set[T]) has(t T) (ok bool) { @@ -18,3 +26,11 @@ func (s set[T]) copy() set[T] { } return output } + +func (s set[T]) keys() (out []T) { + out = make([]T, 0, len(s)) + for k := range s { + out = append(out, k) + } + return +} diff --git a/stabilize.go b/stabilize.go index 890e498..cce7ade 100644 --- a/stabilize.go +++ b/stabilize.go @@ -33,9 +33,6 @@ func (graph *Graph) Stabilize(ctx context.Context) (err error) { if err != nil { break } - if err = graph.recomputeHeights(); err != nil { - break - } } graph.recomputeHeap.add(immediateRecompute...) return diff --git a/stabilize_test.go b/stabilize_test.go index fbaa711..0735bfb 100644 --- a/stabilize_test.go +++ b/stabilize_test.go @@ -155,75 +155,6 @@ func Test_Stabilize_updateHandlers(t *testing.T) { Equal(t, 2, updates) } -func Test_Stabilize_observedHandlers(t *testing.T) { - ctx := testContext() - g := New() - - v0 := Var(g, "foo") - v1 := Var(g, "bar") - m0 := Map2(g, v0, v1, func(a, b string) string { - return a + " " + b - }) - - var observes int - m0.Node().OnObserved(func(IObserver) { - observes++ - }) - - _ = Observe(g, m0) - - err := g.Stabilize(ctx) - Nil(t, err) - Equal(t, 1, observes) - - v0.Set("not foo") - err = g.Stabilize(ctx) - Nil(t, err) - Equal(t, 1, observes) - - _ = Observe(g, m0) - Equal(t, 2, observes) -} - -func Test_Stabilize_unobservedHandlers(t *testing.T) { - ctx := testContext() - g := New() - - v0 := Var(g, "foo") - v1 := Var(g, "bar") - m0 := Map2(g, v0, v1, func(a, b string) string { - return a + " " + b - }) - - var observes, unobserves int - m0.Node().OnObserved(func(IObserver) { - observes++ - }) - m0.Node().OnUnobserved(func(IObserver) { - unobserves++ - }) - - o0 := Observe(g, m0) - - err := g.Stabilize(ctx) - Nil(t, err) - Equal(t, 1, observes) - Equal(t, 0, unobserves) - - v0.Set("not foo") - err = g.Stabilize(ctx) - Nil(t, err) - Equal(t, 1, observes) - - _ = Observe(g, m0) - Equal(t, 2, observes) - Equal(t, 0, unobserves) - - o0.Unobserve(ctx) - Equal(t, 2, observes) - Equal(t, 1, unobserves) -} - func Test_Stabilize_unevenHeights(t *testing.T) { ctx := testContext() g := New() @@ -524,18 +455,16 @@ func Test_Stabilize_Bind(t *testing.T) { _ = Observe(g, mb) - Equal(t, true, g.IsObserving(sw)) + Equal(t, true, g.Has(sw)) err := g.Stabilize(ctx) Nil(t, err) - Equal(t, false, g.IsObserving(i0)) - Equal(t, false, g.IsObserving(m0)) - Nil(t, i0.Node().graph, "i0 should not be in the graph after the first stabilization") - Nil(t, m0.Node().graph, "m0 should not be in the graph after the first stabilization") + Equal(t, true, g.Has(i0)) + Equal(t, true, g.Has(m0)) - Equal(t, true, g.IsObserving(i1)) - Equal(t, true, g.IsObserving(m1)) + Equal(t, true, g.Has(i1)) + Equal(t, true, g.Has(m1)) NotNil(t, i1.Node().graph, "i1 should be in the graph after the first stabilization") NotNil(t, m1.Node().graph, "m1 should be in the graph after the first stabilization") @@ -547,15 +476,13 @@ func Test_Stabilize_Bind(t *testing.T) { err = g.Stabilize(ctx) Nil(t, err) - Equal(t, true, g.IsObserving(i0)) - Equal(t, true, g.IsObserving(m0)) + Equal(t, true, g.Has(i0)) + Equal(t, true, g.Has(m0)) NotNil(t, i0.Node().graph, "i0 should be in the graph after the second stabilization") NotNil(t, m0.Node().graph, "m0 should be in the graph after the second stabilization") - Equal(t, false, g.IsObserving(i1)) - Equal(t, false, g.IsObserving(m1)) - Nil(t, i1.Node().graph, "i1 should not be in the graph after the second stabilization") - Nil(t, m1.Node().graph, "m1 should not be in the graph after the second stabilization") + Equal(t, true, g.Has(i1)) + Equal(t, true, g.Has(m1)) Equal(t, "foo-moo-baz", mb.Value()) } @@ -590,7 +517,6 @@ func Test_Stabilize_BindIf(t *testing.T) { err = g.Stabilize(ctx) Nil(t, err) - Nil(t, i1.Node().graph, "i0 should be in the graph after the third stabilization") NotNil(t, i0.Node().graph, "i1 should not be in the graph after the third stabilization") Equal(t, "foo", b.Value()) @@ -708,7 +634,7 @@ func Test_Stabilize_Bind4(t *testing.T) { Equal(t, "xaxbxcxd", o.Value()) } -func Test_Stabilize_cutoff(t *testing.T) { +func Test_Stabilize_Cutoff(t *testing.T) { ctx := testContext() g := New() @@ -755,7 +681,7 @@ func Test_Stabilize_cutoff(t *testing.T) { Equal(t, 13.26, output.Value()) } -func Test_Stabilize_cutoffContext(t *testing.T) { +func Test_Stabilize_CutoffContext(t *testing.T) { ctx := testContext() g := New() input := Var(g, 3.14) @@ -803,7 +729,7 @@ func Test_Stabilize_cutoffContext(t *testing.T) { Equal(t, 13.26, output.Value()) } -func Test_Stabilize_cutoffContext_error(t *testing.T) { +func Test_Stabilize_CutoffContext_error(t *testing.T) { ctx := testContext() g := New() input := Var(g, 3.14) @@ -849,7 +775,7 @@ func Test_Stabilize_cutoffContext_error(t *testing.T) { Equal(t, 0, output.Value()) } -func Test_Stabilize_cutoff2(t *testing.T) { +func Test_Stabilize_Cutoff2(t *testing.T) { ctx := testContext() g := New() @@ -912,7 +838,7 @@ func Test_Stabilize_cutoff2(t *testing.T) { Equal(t, 13.26, output.Value()) } -func Test_Stabilize_cutoff2Context_error(t *testing.T) { +func Test_Stabilize_Cutoff2Context_error(t *testing.T) { ctx := testContext() g := New() epsilon := Var(g, 0.1) @@ -960,7 +886,7 @@ func Test_Stabilize_cutoff2Context_error(t *testing.T) { Equal(t, 0, output.Value()) } -func Test_Stabilize_watch(t *testing.T) { +func Test_Stabilize_Watch(t *testing.T) { ctx := testContext() g := New() @@ -1270,7 +1196,7 @@ func Test_Stabilize_MapNContext_error(t *testing.T) { Equal(t, 0, mn.Value()) } -func Test_Stabilize_func(t *testing.T) { +func Test_Stabilize_Func(t *testing.T) { ctx := testContext() g := New() @@ -1302,7 +1228,7 @@ func Test_Stabilize_func(t *testing.T) { Equal(t, "not hello world!", m.Value()) } -func Test_Stabilize_foldMap(t *testing.T) { +func Test_Stabilize_FoldMap(t *testing.T) { ctx := testContext() g := New() @@ -1324,7 +1250,7 @@ func Test_Stabilize_foldMap(t *testing.T) { Equal(t, 21, mf.Value()) } -func Test_Stabilize_foldLeft(t *testing.T) { +func Test_Stabilize_FoldLeft(t *testing.T) { ctx := testContext() g := New() @@ -1346,7 +1272,7 @@ func Test_Stabilize_foldLeft(t *testing.T) { Equal(t, "123456", mf.Value()) } -func Test_Stabilize_foldRight(t *testing.T) { +func Test_Stabilize_FoldRight(t *testing.T) { ctx := testContext() g := New() @@ -1373,7 +1299,7 @@ func Test_Stabilize_foldRight(t *testing.T) { Equal(t, "654321654321", mf.Value()) } -func Test_Stabilize_freeze(t *testing.T) { +func Test_Stabilize_Freeze(t *testing.T) { ctx := testContext() g := New() @@ -1395,7 +1321,7 @@ func Test_Stabilize_freeze(t *testing.T) { Equal(t, "hello", fv.Value()) } -func Test_Stabilize_always_cutoff(t *testing.T) { +func Test_Stabilize_Always_Cutoff(t *testing.T) { ctx := testContext() g := New() @@ -1430,7 +1356,7 @@ func Test_Stabilize_always_cutoff(t *testing.T) { Equal(t, "test-2", o.Value()) } -func Test_Stabilize_always_cutoff_error(t *testing.T) { +func Test_Stabilize_Always_Cutoff_error(t *testing.T) { ctx := testContext() g := New() @@ -1506,7 +1432,7 @@ func Test_Stabilize_handlers(t *testing.T) { Equal(t, true, endWasBlueDye) } -func Test_Stabilize_bindCombination(t *testing.T) { +func Test_Stabilize_Bind_jsCombination(t *testing.T) { ctx := testContext() g := New() diff --git a/timer.go b/timer.go index 8554ea4..4ccbdfc 100644 --- a/timer.go +++ b/timer.go @@ -11,14 +11,14 @@ import ( // When it stabilizes, it stabilizes the input node, and assumes its // value in observation. func Timer[A any](scope Scope, input Incr[A], every time.Duration) TimerIncr[A] { - t := &timerIncr[A]{ + t := WithinScope(scope, &timerIncr[A]{ n: NewNode("timer"), clockSource: func(_ context.Context) time.Time { return time.Now().UTC() }, every: every, input: input, - } + }) Link(t, input) - return WithinScope(scope, t) + return t } // TimerIncr is the exported methods of a Timer. diff --git a/unlink.go b/unlink.go index e2e30c4..8b8971f 100644 --- a/unlink.go +++ b/unlink.go @@ -5,4 +5,5 @@ package incr func Unlink(child, input INode) { child.Node().removeParent(input.Node().id) input.Node().removeChild(child.Node().id) + child.Node().createdIn.scopeGraph().checkIfUnnecessary(child) } diff --git a/watch.go b/watch.go index abda124..7be13f8 100644 --- a/watch.go +++ b/watch.go @@ -6,14 +6,14 @@ import ( ) // Watch returns a new watch incremental that tracks -// values for a given incremental. +// values for a given incremental each time it stabilizes. func Watch[A any](scope Scope, i Incr[A]) WatchIncr[A] { - o := &watchIncr[A]{ + o := WithinScope(scope, &watchIncr[A]{ n: NewNode("watch"), incr: i, - } + }) Link(o, i) - return WithinScope(scope, o) + return o } // WatchIncr is a type that implements the watch interface.