Skip to content

Commit

Permalink
Refactors how observers are treated; they only mark leaf nodes as nec…
Browse files Browse the repository at this point in the history
…essary now. (#13)
  • Loading branch information
wcharczuk authored Feb 12, 2024
1 parent a5e4974 commit 602b869
Show file tree
Hide file tree
Showing 52 changed files with 707 additions and 1,026 deletions.
18 changes: 12 additions & 6 deletions adjust_heights_heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,16 @@ func (ah *adjustHeightsHeap) setHeight(node INode, height int) error {
return nil
}

func (ah *adjustHeightsHeap) ensureHeightRequirement(child, parent INode) error {
func (ah *adjustHeightsHeap) ensureHeightRequirement(originalChild, originalParent, child, parent INode) error {
ah.mu.Lock()
defer ah.mu.Unlock()
return ah.ensureHeightRequirementUnsafe(child, parent)
return ah.ensureHeightRequirementUnsafe(originalChild, originalParent, child, parent)
}

func (ah *adjustHeightsHeap) ensureHeightRequirementUnsafe(child, parent INode) error {
func (ah *adjustHeightsHeap) ensureHeightRequirementUnsafe(originalChild, originalParent, child, parent INode) error {
if originalParent.Node().id == child.Node().id {
return fmt.Errorf("cycle detected at %v", child)
}
if parent.Node().height >= child.Node().height {
if err := ah.setHeight(child, parent.Node().height+1); err != nil {
return err
Expand All @@ -130,14 +133,17 @@ func (ah *adjustHeightsHeap) ensureHeightRequirementUnsafe(child, parent INode)
return nil
}

func (ah *adjustHeightsHeap) adjustHeights(rh *recomputeHeap) error {
func (ah *adjustHeightsHeap) adjustHeights(rh *recomputeHeap, originalChild, originalParent INode) error {
ah.mu.Lock()
defer ah.mu.Unlock()
if err := ah.ensureHeightRequirementUnsafe(originalChild, originalParent, originalChild, originalParent); err != nil {
return err
}
for len(ah.lookup) > 0 {
node, _ := ah.removeMinUnsafe()
rh.fix(node.Node().id)
for _, child := range node.Node().children {
if err := ah.ensureHeightRequirementUnsafe(child, node); err != nil {
if err := ah.ensureHeightRequirementUnsafe(originalChild, originalParent, child, node); err != nil {
return err
}
}
Expand All @@ -146,7 +152,7 @@ func (ah *adjustHeightsHeap) adjustHeights(rh *recomputeHeap) error {
if scopeOK {
for _, nodeOnRight := range scope.rhsNodes {
if node.Node().graph.isNecessary(nodeOnRight) {
if err := ah.ensureHeightRequirementUnsafe(node, nodeOnRight); err != nil {
if err := ah.ensureHeightRequirementUnsafe(originalChild, originalParent, node, nodeOnRight); err != nil {
return err
}
}
Expand Down
6 changes: 3 additions & 3 deletions always.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package incr
// Always returns an incremental that is always stale and will be
// marked for recomputation.
func Always[A any](scope Scope, input Incr[A]) Incr[A] {
a := &alwaysIncr[A]{
a := WithinScope(scope, &alwaysIncr[A]{
n: NewNode("always"),
input: input,
}
})
Link(a, input)
return WithinScope(scope, a)
return a
}

// AlwaysIncr is a type that implements the always stale incremental.
Expand Down
1 change: 0 additions & 1 deletion always_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func Test_Always(t *testing.T) {
m1.Node().OnUpdate(func(_ context.Context) {
updates++
})

o := Observe(g, m1)

ctx := testContext()
Expand Down
23 changes: 6 additions & 17 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func Benchmark_Stabilize_connectedGraph_with_nestedBinds_128(b *testing.B) {
benchmarkConnectedGraphWithNestedBinds(128, b)
}

func benchmarkSize(size int, b *testing.B) {
func makeBenchmarkGraph(size int) (*Graph, []Incr[string]) {
graph := New()
nodes := make([]Incr[string], size)
for x := 0; x < size; x++ {
Expand All @@ -124,7 +124,11 @@ func benchmarkSize(size int, b *testing.B) {
}

_ = Observe(graph, nodes[len(nodes)-1])
return graph, nodes
}

func benchmarkSize(size int, b *testing.B) {
graph, nodes := makeBenchmarkGraph(size)
// this is what we care about
ctx := context.Background()
b.ResetTimer()
Expand Down Expand Up @@ -152,22 +156,7 @@ func benchmarkSize(size int, b *testing.B) {
}

func benchmarkParallelSize(size int, b *testing.B) {
graph := New()
nodes := make([]Incr[string], size)
for x := 0; x < size; x++ {
nodes[x] = Var(graph, fmt.Sprintf("var_%d", x))
}

var cursor int
for x := size; x > 0; x >>= 1 {
for y := 0; y < x-1; y += 2 {
n := Map2(graph, nodes[cursor+y], nodes[cursor+y+1], concat)
nodes = append(nodes, n)
}
cursor += x
}

_ = Observe(graph, nodes[0])
graph, nodes := makeBenchmarkGraph(size)

// this is what we care about
ctx := context.Background()
Expand Down
85 changes: 40 additions & 45 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ func Bind[A, B any](scope Scope, input Incr[A], fn func(Scope, A) Incr[B]) BindI
// If an error returned, the bind is aborted, the error listener(s) will fire for the node, and the
// computation will stop.
func BindContext[A, B any](scope Scope, input Incr[A], fn func(context.Context, Scope, A) (Incr[B], error)) BindIncr[B] {
o := &bindIncr[A, B]{
o := WithinScope(scope, &bindIncr[A, B]{
n: NewNode("bind"),
input: input,
fn: fn,
}
})
o.scope = &bindScope{
bind: o,
input: input,
bind: o,
}
Link(o, input)
return WithinScope(scope, o)
return o
}

// BindIncr is a node that implements Bind, which can dynamically swap out
Expand Down Expand Up @@ -100,7 +101,7 @@ func (b *bindIncr[A, B]) Scope() Scope {
}

func (b *bindIncr[A, B]) didInputChange() bool {
return b.input.Node().changedAt >= b.n.recomputedAt
return b.input.Node().changedAt >= b.n.changedAt
}

func (b *bindIncr[A, B]) Stabilize(ctx context.Context) error {
Expand All @@ -111,6 +112,7 @@ func (b *bindIncr[A, B]) Stabilize(ctx context.Context) error {
// we do want to propagate changes to the bound node to the bind
// node's children however, so some trickery is involved.
if !b.didInputChange() {
TracePrintf(ctx, "%v input unchanged", b)
// NOTE (wc): ok so this is a tangle.
// we halt computation based on boundAt for nodes that
// set their bound at. So if our bound node triggered
Expand All @@ -129,20 +131,25 @@ func (b *bindIncr[A, B]) Stabilize(ctx context.Context) error {
if b.bound != nil && newIncr != nil {
if b.bound.Node().id != newIncr.Node().id {
bindChanged = true
b.unlinkOldBound(ctx, b.n.Observers()...)
b.unlinkOldBound(ctx, b.n.observers...)
if err := b.linkNewBound(ctx, newIncr); err != nil {
return err
}
} else {
bindChanged = b.bound.Node().changedAt > b.n.boundAt
TracePrintf(ctx, "%v bound to same node after stabilization", b)
}
} else if newIncr != nil {
bindChanged = true
b.linkBindChange(ctx)
if err := b.linkBindChange(ctx); err != nil {
return err
}
if err := b.linkNewBound(ctx, newIncr); err != nil {
return err
}
} else if b.bound != nil {
bindChanged = true
b.unlinkOldBound(ctx, b.n.Observers()...)
b.unlinkOldBound(ctx, b.n.observers...)
b.unlinkBindChange(ctx)
}
if bindChanged {
Expand All @@ -158,54 +165,38 @@ func (b *bindIncr[A, B]) Unobserve(ctx context.Context, observers ...IObserver)

func (b *bindIncr[A, B]) Link(ctx context.Context) (err error) {
if b.bound != nil {
_ = b.n.graph.adjustHeightsHeap.ensureHeightRequirement(b, b.bound)
for _, n := range b.scope.rhsNodes {
if typed, ok := n.(IBind); ok {
if err = typed.Link(ctx); err != nil {
return
}
}
}
err = b.n.graph.adjustHeightsHeap.adjustHeights(b.n.graph.recomputeHeap, b, b.bound)
if err != nil {
return
}
}
return
}

func (b *bindIncr[A, B]) linkBindChange(ctx context.Context) {
b.bindChange = &bindChangeIncr[A, B]{
func (b *bindIncr[A, B]) linkBindChange(ctx context.Context) error {
b.bindChange = WithinScope(b.n.createdIn, &bindChangeIncr[A, B]{
n: NewNode("bind-lhs-change"),
lhs: b.input,
rhs: b.bound,
}
})
if b.n.label != "" {
b.bindChange.n.SetLabel(fmt.Sprintf("%s-change", b.n.label))
}
b.bindChange.n.createdIn = b.n.createdIn
Link(b.bindChange, b.input)
b.n.graph.observeSingleNode(b.bindChange, b.n.Observers()...)
}

func (b *bindIncr[A, B]) unlinkBindChange(ctx context.Context) {
if b.bindChange != nil {
if b.bound != nil {
Unlink(b.bound, b.bindChange)
}
Unlink(b.bindChange, b.input)

// NOTE (wc): we don't do a """typical""" unobserve here because we
// really don't care; if it's time to unlink our bind change, it's our
// bind change, there is no way to observe it directly, so we'll just
// shoot it in the face ourselves.
b.n.graph.removeNodeFromGraph(b.bindChange)
b.bindChange = nil
}
return nil
}

func (b *bindIncr[A, B]) linkNewBound(ctx context.Context, newIncr Incr[B]) (err error) {
b.bound = newIncr
Link(b.bound, b.bindChange)
Link(b, b.bound)
b.n.graph.observeNodes(b.bound, b.n.Observers()...)
_ = b.n.graph.adjustHeightsHeap.ensureHeightRequirement(b, b.bound)
for _, n := range b.scope.rhsNodes {
if typed, ok := n.(IBind); ok {
if err = typed.Link(ctx); err != nil {
Expand All @@ -217,27 +208,31 @@ func (b *bindIncr[A, B]) linkNewBound(ctx context.Context, newIncr Incr[B]) (err
return
}

func (b *bindIncr[A, B]) unlinkBindChange(ctx context.Context) {
if b.bindChange != nil {
if b.bound != nil {
Unlink(b.bound, b.bindChange)
}
Unlink(b.bindChange, b.input)
// NOTE (wc): we don't do a """typical""" unobserve here because we
// really don't care; if it's time to unlink our bind change, it's our
// bind change, there is no way to observe it directly, so we'll just
// shoot it in the face ourselves.
if !b.n.graph.isNecessary(b.bindChange) {
b.bindChange = nil
}
}
}

func (b *bindIncr[A, B]) unlinkOldBound(ctx context.Context, observers ...IObserver) {
if b.bound != nil {
TracePrintf(ctx, "%v unbinding old rhs %v", b, b.bound)
Unlink(b.bound, b.bindChange)
b.removeNodesFromScope(ctx, b.scope, observers...)
Unlink(b, b.bound)
b.n.graph.unobserveNodes(ctx, b.bound, observers...)
TracePrintf(ctx, "%v unbound old rhs %v", b, b.bound)
b.bound = nil
}
}

func (b *bindIncr[A, B]) removeNodesFromScope(ctx context.Context, scope *bindScope, observers ...IObserver) {
for _, n := range scope.rhsNodes {
n.Node().createdIn = nil
if typed, ok := n.(IUnobserve); ok {
typed.Unobserve(ctx, observers...)
}
}
scope.rhsNodes = nil
}

func (b *bindIncr[A, B]) String() string {
return b.n.String()
}
Expand Down
2 changes: 1 addition & 1 deletion bind_if.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import "context"
func BindIf[A any](scope Scope, p Incr[bool], fn func(context.Context, Scope, bool) (Incr[A], error)) BindIncr[A] {
b := BindContext[bool, A](scope, p, fn).(*bindIncr[bool, A])
b.Node().SetKind("bind_if")
return WithinScope(scope, b)
return b
}
Loading

0 comments on commit 602b869

Please sign in to comment.