From a7eb215a6f26d56e0ab03240a19e7445d82c58c8 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Tue, 11 Aug 2020 03:51:46 +0200 Subject: [PATCH 01/13] p2p/nodestate: add unit test for correct callback order --- p2p/nodestate/nodestate_test.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go index f6ff3ffc07d8..935062837a88 100644 --- a/p2p/nodestate/nodestate_test.go +++ b/p2p/nodestate/nodestate_test.go @@ -387,3 +387,34 @@ func TestDuplicatedFlags(t *testing.T) { clock.Run(2 * time.Second) check(flags[0], Flags{}, true) } + +func TestCallbackOrder(t *testing.T) { + mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} + + s, flags, _ := testSetup([]bool{false, false, false, false}, nil) + ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) + + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { + if newState.Equals(flags[0]) { + ns.SetState(n, flags[1], Flags{}, 0) + ns.SetState(n, flags[2], Flags{}, 0) + } + }) + ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { + if newState.Equals(flags[1]) { + ns.SetState(n, flags[3], Flags{}, 0) + } + }) + lastState := Flags{} + ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) { + if !oldState.Equals(lastState) { + t.Fatalf("Wrong callback order") + } + lastState = newState + }) + + ns.Start() + defer ns.Stop() + + ns.SetState(testNode(1), flags[0], Flags{}, 0) +} From 5698c96c7024d14ea8a40415eb4a7376c7bcf219 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Tue, 11 Aug 2020 03:53:38 +0200 Subject: [PATCH 02/13] p2p/nodestate: ensure correct callback order --- p2p/nodestate/nodestate.go | 168 +++++++++++++++++++++++++++---------- 1 file changed, 124 insertions(+), 44 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index 7091281aeaef..469fd68120a9 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -18,6 +18,7 @@ package nodestate import ( "errors" + "fmt" "reflect" "sync" "time" @@ -32,6 +33,8 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) +const debugPrints = false + type ( // NodeStateMachine connects different system components operating on subsets of // network nodes. Node states are represented by 64 bit vectors with each bit assigned @@ -60,6 +63,8 @@ type ( dbNodeKey []byte nodes map[enode.ID]*nodeInfo offlineCallbackList []offlineCallback + callbackCount int + callbackWait *sync.Cond // Registered state flags or fields. Modifications are allowed // only when the node state machine has not been started. @@ -128,11 +133,12 @@ type ( // nodeInfo contains node state, fields and state timeouts nodeInfo struct { - node *enode.Node - state bitMask - timeouts []*nodeStateTimeout - fields []interface{} - db, dirty bool + node *enode.Node + state bitMask + timeouts []*nodeStateTimeout + fields []interface{} + pendingCallbacks []func() + db, dirty bool } nodeInfoEnc struct { @@ -158,7 +164,7 @@ type ( } offlineCallback struct { - node *enode.Node + node *nodeInfo state bitMask fields []interface{} } @@ -319,10 +325,11 @@ func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Cloc nodes: make(map[enode.ID]*nodeInfo), fields: make([]*fieldInfo, len(setup.fields)), } + ns.callbackWait = sync.NewCond(&ns.lock) stateNameMap := make(map[string]int) for index, flag := range setup.flags { if _, ok := stateNameMap[flag.name]; ok { - panic("Node state flag name collision") + panic("Node state flag name collision: " + flag.name) } stateNameMap[flag.name] = index if flag.persistent { @@ -332,7 +339,7 @@ func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Cloc fieldNameMap := make(map[string]int) for index, field := range setup.fields { if _, ok := fieldNameMap[field.name]; ok { - panic("Node field name collision") + panic("Node field name collision: " + field.name) } ns.fields[index] = &fieldInfo{fieldDefinition: field} fieldNameMap[field.name] = index @@ -415,10 +422,13 @@ func (ns *NodeStateMachine) Start() { // Stop stops the state machine and saves its state if a database was supplied func (ns *NodeStateMachine) Stop() { ns.lock.Lock() + if ns.callbackCount != 0 { + ns.callbackWait.Wait() + } for _, node := range ns.nodes { fields := make([]interface{}, len(node.fields)) copy(fields, node.fields) - ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) + ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node, node.state, fields}) } ns.stopped = true if ns.db != nil { @@ -491,7 +501,7 @@ func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) { node.state = enc.State fields := make([]interface{}, len(node.fields)) copy(fields, node.fields) - ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node.node, node.state, fields}) + ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node, node.state, fields}) log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) } @@ -620,6 +630,9 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, } oldState := node.state newState := (node.state & (^reset)) | set + if debugPrints { + fmt.Println("SetState", n.ID(), "old", Flags{oldState, ns.setup}, "new", Flags{newState, ns.setup}, "set", setFlags, "reset", resetFlags, "timeout", timeout) + } changed := oldState ^ newState node.state = newState @@ -646,53 +659,104 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, node.dirty = true } } - ns.lock.Unlock() - // call state update subscription callbacks without holding the mutex - for _, sub := range ns.stateSubs { - if changed&sub.mask != 0 { - sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) + callback := func() { + for _, sub := range ns.stateSubs { + if changed&sub.mask != 0 { + sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) + } } - } - if newState == 0 { - // call field subscriptions for discarded fields - for i, v := range node.fields { - if v != nil { - f := ns.fields[i] - if len(f.subs) > 0 { - for _, cb := range f.subs { - cb(n, Flags{setup: ns.setup}, v, nil) + if newState == 0 { + // call field subscriptions for discarded fields + for i, v := range node.fields { + if v != nil { + if debugPrints { + fmt.Println("discardField", n.ID(), ns.setup.fields[i].name, v) + } + f := ns.fields[i] + if len(f.subs) > 0 { + for _, cb := range f.subs { + cb(n, Flags{setup: ns.setup}, v, nil) + } } } } } } + callNow := node.pendingCallbacks == nil + node.pendingCallbacks = append(node.pendingCallbacks, callback) + ns.callbackCount++ + list := node.pendingCallbacks + ns.lock.Unlock() + // call state update subscription callbacks without holding the mutex + if callNow { + ns.processCallbacks(node, list) + } +} + +// processCallbacks runs pending callbacks of a given node in a guaranteed correct order. +// Callbacks resulting from a state/field change performed in a previous callback are always +// put at the end of the pending list and therefore processed after all callbacks resulting +// from the previous state/field change. +func (ns *NodeStateMachine) processCallbacks(node *nodeInfo, list []func()) { + for list != nil { + for _, cb := range list { + cb() + } + ns.lock.Lock() + node.pendingCallbacks = node.pendingCallbacks[len(list):] + if len(node.pendingCallbacks) == 0 { + node.pendingCallbacks = nil + } + ns.callbackCount -= len(list) + if ns.callbackCount == 0 { + ns.callbackWait.Signal() + } + list = node.pendingCallbacks + ns.lock.Unlock() + } } // offlineCallbacks calls state update callbacks at startup or shutdown func (ns *NodeStateMachine) offlineCallbacks(start bool) { for _, cb := range ns.offlineCallbackList { - for _, sub := range ns.stateSubs { - offState := offlineState & sub.mask - onState := cb.state & sub.mask - if offState != onState { - if start { - sub.callback(cb.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) - } else { - sub.callback(cb.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) - } - } - } - for i, f := range cb.fields { - if f != nil && ns.fields[i].subs != nil { - for _, fsub := range ns.fields[i].subs { + cb := cb + callback := func() { + for _, sub := range ns.stateSubs { + offState := offlineState & sub.mask + onState := cb.state & sub.mask + if offState != onState { if start { - fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) + sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) } else { - fsub(cb.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) + sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) + } + } + } + for i, f := range cb.fields { + if f != nil && ns.fields[i].subs != nil { + for _, fsub := range ns.fields[i].subs { + if start { + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) + } else { + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) + } } } } } + ns.lock.Lock() + if cb.node.pendingCallbacks != nil { + panic("Fatal: unfinished callback") + } + cb.node.pendingCallbacks = []func(){callback} + ns.callbackCount++ + ns.lock.Unlock() + } + for _, cb := range ns.offlineCallbackList { + ns.lock.Lock() + list := cb.node.pendingCallbacks + ns.lock.Unlock() + ns.processCallbacks(cb.node, list) } ns.offlineCallbackList = nil } @@ -723,6 +787,9 @@ func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time ns.removeTimeouts(node, mask) t := &nodeStateTimeout{mask: mask} t.timer = ns.clock.AfterFunc(timeout, func() { + if debugPrints { + fmt.Println("timeout", n.ID(), Flags{mask, ns.setup}) + } ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) }) node.timeouts = append(node.timeouts, t) @@ -792,6 +859,9 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface return errors.New("invalid field type") } oldValue := node.fields[fieldIndex] + if debugPrints { + fmt.Println("SetField", n.ID(), Flags{node.state, ns.setup}, field.setup.fields[field.index].name, oldValue, value) + } if value == oldValue { ns.lock.Unlock() return nil @@ -802,12 +872,22 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface } state := node.state - ns.lock.Unlock() - if len(f.subs) > 0 { - for _, cb := range f.subs { - cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) + callback := func() { + if len(f.subs) > 0 { + for _, cb := range f.subs { + cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) + } } } + callNow := node.pendingCallbacks == nil + node.pendingCallbacks = append(node.pendingCallbacks, callback) + ns.callbackCount++ + list := node.pendingCallbacks + ns.lock.Unlock() + // call field update subscription callbacks without holding the mutex + if callNow { + ns.processCallbacks(node, list) + } return nil } From 5a1413518ea6db4ce66aedae8934cd4e3a1aaacf Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Thu, 13 Aug 2020 17:08:07 +0200 Subject: [PATCH 03/13] p2p/nodestate: add Caller --- p2p/nodestate/nodestate.go | 99 +++++++++++++++++---------------- p2p/nodestate/nodestate_test.go | 88 ++++++++++++++--------------- 2 files changed, 94 insertions(+), 93 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index 469fd68120a9..4a32c0f05ab4 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -121,24 +121,27 @@ type ( // of node flags with each bit assigned to a flag index (LSB represents flag 0). bitMask uint64 + Caller struct { + pending []func() + } + // StateCallback is a subscription callback which is called when one of the // state flags that is included in the subscription state mask is changed. // Note: oldState and newState are also masked with the subscription mask so only // the relevant bits are included. - StateCallback func(n *enode.Node, oldState, newState Flags) + StateCallback func(n *enode.Node, oldState, newState Flags, caller *Caller) // FieldCallback is a subscription callback which is called when the value of // a specific field is changed. - FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{}) + FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{}, caller *Caller) // nodeInfo contains node state, fields and state timeouts nodeInfo struct { - node *enode.Node - state bitMask - timeouts []*nodeStateTimeout - fields []interface{} - pendingCallbacks []func() - db, dirty bool + node *enode.Node + state bitMask + timeouts []*nodeStateTimeout + fields []interface{} + db, dirty bool } nodeInfoEnc struct { @@ -610,7 +613,7 @@ func (ns *NodeStateMachine) Persist(n *enode.Node) error { // It only returns after all subsequent immediate changes (including those changed by the // callbacks) have been processed. If a flag with a timeout is set again, the operation // removes or replaces the existing timeout. -func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { +func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration, caller *Caller) { ns.lock.Lock() ns.checkStarted() if ns.stopped { @@ -662,7 +665,7 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, callback := func() { for _, sub := range ns.stateSubs { if changed&sub.mask != 0 { - sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) + sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}, caller) } } if newState == 0 { @@ -675,21 +678,25 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, f := ns.fields[i] if len(f.subs) > 0 { for _, cb := range f.subs { - cb(n, Flags{setup: ns.setup}, v, nil) + cb(n, Flags{setup: ns.setup}, v, nil, caller) } } } } } } - callNow := node.pendingCallbacks == nil - node.pendingCallbacks = append(node.pendingCallbacks, callback) - ns.callbackCount++ - list := node.pendingCallbacks + callNow := caller == nil + if callNow { + caller = &Caller{} + } else { + caller.pending = append(caller.pending, callback) + ns.callbackCount++ + } ns.lock.Unlock() // call state update subscription callbacks without holding the mutex if callNow { - ns.processCallbacks(node, list) + callback() + ns.processCallbacks(caller) } } @@ -697,27 +704,26 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, // Callbacks resulting from a state/field change performed in a previous callback are always // put at the end of the pending list and therefore processed after all callbacks resulting // from the previous state/field change. -func (ns *NodeStateMachine) processCallbacks(node *nodeInfo, list []func()) { - for list != nil { +func (ns *NodeStateMachine) processCallbacks(caller *Caller) { + for len(caller.pending) != 0 { + list := caller.pending for _, cb := range list { cb() } ns.lock.Lock() - node.pendingCallbacks = node.pendingCallbacks[len(list):] - if len(node.pendingCallbacks) == 0 { - node.pendingCallbacks = nil - } + caller.pending = caller.pending[len(list):] ns.callbackCount -= len(list) if ns.callbackCount == 0 { ns.callbackWait.Signal() } - list = node.pendingCallbacks ns.lock.Unlock() } } // offlineCallbacks calls state update callbacks at startup or shutdown func (ns *NodeStateMachine) offlineCallbacks(start bool) { + ns.lock.Lock() + caller := &Caller{} for _, cb := range ns.offlineCallbackList { cb := cb callback := func() { @@ -726,9 +732,9 @@ func (ns *NodeStateMachine) offlineCallbacks(start bool) { onState := cb.state & sub.mask if offState != onState { if start { - sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) + sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}, caller) } else { - sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) + sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}, caller) } } } @@ -736,29 +742,20 @@ func (ns *NodeStateMachine) offlineCallbacks(start bool) { if f != nil && ns.fields[i].subs != nil { for _, fsub := range ns.fields[i].subs { if start { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f, caller) } else { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil, caller) } } } } } - ns.lock.Lock() - if cb.node.pendingCallbacks != nil { - panic("Fatal: unfinished callback") - } - cb.node.pendingCallbacks = []func(){callback} + caller.pending = append(caller.pending, callback) ns.callbackCount++ - ns.lock.Unlock() - } - for _, cb := range ns.offlineCallbackList { - ns.lock.Lock() - list := cb.node.pendingCallbacks - ns.lock.Unlock() - ns.processCallbacks(cb.node, list) } ns.offlineCallbackList = nil + ns.lock.Unlock() + ns.processCallbacks(caller) } // AddTimeout adds a node state timeout associated to the given state flag(s). @@ -790,7 +787,7 @@ func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time if debugPrints { fmt.Println("timeout", n.ID(), Flags{mask, ns.setup}) } - ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) + ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0, nil) }) node.timeouts = append(node.timeouts, t) if mask&ns.saveFlags != 0 { @@ -839,7 +836,7 @@ func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { } // SetField sets the given field of the given node -func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error { +func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}, caller *Caller) error { ns.lock.Lock() ns.checkStarted() if ns.stopped { @@ -875,18 +872,22 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface callback := func() { if len(f.subs) > 0 { for _, cb := range f.subs { - cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) + cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value, caller) } } } - callNow := node.pendingCallbacks == nil - node.pendingCallbacks = append(node.pendingCallbacks, callback) - ns.callbackCount++ - list := node.pendingCallbacks + callNow := caller == nil + if callNow { + caller = &Caller{} + } else { + caller.pending = append(caller.pending, callback) + ns.callbackCount++ + } ns.lock.Unlock() - // call field update subscription callbacks without holding the mutex + // call state update subscription callbacks without holding the mutex if callNow { - ns.processCallbacks(node, list) + callback() + ns.processCallbacks(caller) } return nil } @@ -929,7 +930,7 @@ func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node { // being in a given set specified by required and disabled state flags func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) { var count int64 - ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) { + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags, caller *Caller) { oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) if newMatch == oldMatch { diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go index 935062837a88..6139a01d27f7 100644 --- a/p2p/nodestate/nodestate_test.go +++ b/p2p/nodestate/nodestate_test.go @@ -70,15 +70,15 @@ func TestCallback(t *testing.T) { set0 := make(chan struct{}, 1) set1 := make(chan struct{}, 1) set2 := make(chan struct{}, 1) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} }) - ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} }) - ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} }) + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags, caller *Caller) { set0 <- struct{}{} }) + ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags, caller *Caller) { set1 <- struct{}{} }) + ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags, caller *Caller) { set2 <- struct{}{} }) ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetState(testNode(1), flags[1], Flags{}, time.Second) - ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second) + ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) + ns.SetState(testNode(1), flags[1], Flags{}, time.Second, nil) + ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second, nil) for i := 0; i < 3; i++ { select { @@ -104,11 +104,11 @@ func TestPersistentFlags(t *testing.T) { ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved - ns.SetState(testNode(2), flags[1], Flags{}, 0) - ns.SetState(testNode(3), flags[2], Flags{}, 0) - ns.SetState(testNode(4), flags[3], Flags{}, 0) - ns.SetState(testNode(5), flags[0], Flags{}, 0) + ns.SetState(testNode(1), flags[0], Flags{}, time.Second, nil) // state with timeout should not be saved + ns.SetState(testNode(2), flags[1], Flags{}, 0, nil) + ns.SetState(testNode(3), flags[2], Flags{}, 0, nil) + ns.SetState(testNode(4), flags[3], Flags{}, 0, nil) + ns.SetState(testNode(5), flags[0], Flags{}, 0, nil) ns.Persist(testNode(5)) select { case <-saveNode: @@ -145,19 +145,19 @@ func TestSetField(t *testing.T) { ns.Start() // Set field before setting state - ns.SetField(testNode(1), fields[0], "hello world") + ns.SetField(testNode(1), fields[0], "hello world", nil) field := ns.GetField(testNode(1), fields[0]) if field != nil { t.Fatalf("Field shouldn't be set before setting states") } // Set field after setting state - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetField(testNode(1), fields[0], "hello world") + ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) + ns.SetField(testNode(1), fields[0], "hello world", nil) field = ns.GetField(testNode(1), fields[0]) if field == nil { t.Fatalf("Field should be set after setting states") } - if err := ns.SetField(testNode(1), fields[0], 123); err == nil { + if err := ns.SetField(testNode(1), fields[0], 123, nil); err == nil { t.Fatalf("Invalid field should be rejected") } // Dirty node should be written back @@ -177,10 +177,10 @@ func TestUnsetField(t *testing.T) { ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, time.Second) - ns.SetField(testNode(1), fields[0], "hello world") + ns.SetState(testNode(1), flags[0], Flags{}, time.Second, nil) + ns.SetField(testNode(1), fields[0], "hello world", nil) - ns.SetState(testNode(1), Flags{}, flags[0], 0) + ns.SetState(testNode(1), Flags{}, flags[0], 0, nil) if field := ns.GetField(testNode(1), fields[0]); field != nil { t.Fatalf("Field should be unset") } @@ -194,7 +194,7 @@ func TestSetState(t *testing.T) { type change struct{ old, new Flags } set := make(chan change, 1) - ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) { + ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags, caller *Caller) { set <- change{ old: oldState, new: newState, @@ -224,25 +224,25 @@ func TestSetState(t *testing.T) { return } } - ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) check(Flags{}, flags[0], true) - ns.SetState(testNode(1), flags[1], Flags{}, 0) + ns.SetState(testNode(1), flags[1], Flags{}, 0, nil) check(flags[0], flags[0].Or(flags[1]), true) - ns.SetState(testNode(1), flags[2], Flags{}, 0) + ns.SetState(testNode(1), flags[2], Flags{}, 0, nil) check(Flags{}, Flags{}, false) - ns.SetState(testNode(1), Flags{}, flags[0], 0) + ns.SetState(testNode(1), Flags{}, flags[0], 0, nil) check(flags[0].Or(flags[1]), flags[1], true) - ns.SetState(testNode(1), Flags{}, flags[1], 0) + ns.SetState(testNode(1), Flags{}, flags[1], 0, nil) check(flags[1], Flags{}, true) - ns.SetState(testNode(1), Flags{}, flags[2], 0) + ns.SetState(testNode(1), Flags{}, flags[2], 0, nil) check(Flags{}, Flags{}, false) - ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second) + ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second, nil) check(Flags{}, flags[0].Or(flags[1]), true) clock.Run(time.Second) check(flags[0].Or(flags[1]), Flags{}, true) @@ -282,9 +282,9 @@ func TestPersistentFields(t *testing.T) { ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetField(testNode(1), fields[0], uint64(100)) - ns.SetField(testNode(1), fields[1], "hello world") + ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) + ns.SetField(testNode(1), fields[0], uint64(100), nil) + ns.SetField(testNode(1), fields[1], "hello world", nil) ns.Stop() ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) @@ -317,7 +317,7 @@ func TestFieldSub(t *testing.T) { lastState Flags lastOldValue, lastNewValue interface{} ) - ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { + ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}, caller *Caller) { lastState, lastOldValue, lastNewValue = state, oldValue, newValue }) check := func(state Flags, oldValue, newValue interface{}) { @@ -326,19 +326,19 @@ func TestFieldSub(t *testing.T) { } } ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetField(testNode(1), fields[0], uint64(100)) + ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) + ns.SetField(testNode(1), fields[0], uint64(100), nil) check(flags[0], nil, uint64(100)) ns.Stop() check(s.OfflineFlag(), uint64(100), nil) ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { + ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}, caller *Caller) { lastState, lastOldValue, lastNewValue = state, oldValue, newValue }) ns2.Start() check(s.OfflineFlag(), nil, uint64(100)) - ns2.SetState(testNode(1), Flags{}, flags[0], 0) + ns2.SetState(testNode(1), Flags{}, flags[0], 0, nil) check(Flags{}, uint64(100), nil) ns2.Stop() } @@ -351,7 +351,7 @@ func TestDuplicatedFlags(t *testing.T) { type change struct{ old, new Flags } set := make(chan change, 1) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags, caller *Caller) { set <- change{oldState, newState} }) @@ -379,9 +379,9 @@ func TestDuplicatedFlags(t *testing.T) { return } } - ns.SetState(testNode(1), flags[0], Flags{}, time.Second) + ns.SetState(testNode(1), flags[0], Flags{}, time.Second, nil) check(Flags{}, flags[0], true) - ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s + ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second, nil) // extend the timeout to 2s check(Flags{}, flags[0], false) clock.Run(2 * time.Second) @@ -394,19 +394,19 @@ func TestCallbackOrder(t *testing.T) { s, flags, _ := testSetup([]bool{false, false, false, false}, nil) ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags, caller *Caller) { if newState.Equals(flags[0]) { - ns.SetState(n, flags[1], Flags{}, 0) - ns.SetState(n, flags[2], Flags{}, 0) + ns.SetState(n, flags[1], Flags{}, 0, caller) + ns.SetState(n, flags[2], Flags{}, 0, caller) } }) - ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { + ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags, caller *Caller) { if newState.Equals(flags[1]) { - ns.SetState(n, flags[3], Flags{}, 0) + ns.SetState(n, flags[3], Flags{}, 0, caller) } }) lastState := Flags{} - ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) { + ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags, caller *Caller) { if !oldState.Equals(lastState) { t.Fatalf("Wrong callback order") } @@ -416,5 +416,5 @@ func TestCallbackOrder(t *testing.T) { ns.Start() defer ns.Stop() - ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) } From 901d2b499fbb703ceef75bf568db834a44b51ba4 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Fri, 21 Aug 2020 13:12:58 +0200 Subject: [PATCH 04/13] p2p/nodestate: reverse checks to improve readability --- p2p/nodestate/nodestate.go | 62 ++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index 4a32c0f05ab4..44a3778a8004 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -668,19 +668,21 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}, caller) } } - if newState == 0 { - // call field subscriptions for discarded fields - for i, v := range node.fields { - if v != nil { - if debugPrints { - fmt.Println("discardField", n.ID(), ns.setup.fields[i].name, v) - } - f := ns.fields[i] - if len(f.subs) > 0 { - for _, cb := range f.subs { - cb(n, Flags{setup: ns.setup}, v, nil, caller) - } - } + if newState != 0 { + return + } + // call field subscriptions for discarded fields + for i, v := range node.fields { + if v == nil { + continue + } + if debugPrints { + fmt.Println("discardField", n.ID(), ns.setup.fields[i].name, v) + } + f := ns.fields[i] + if len(f.subs) > 0 { + for _, cb := range f.subs { + cb(n, Flags{setup: ns.setup}, v, nil, caller) } } } @@ -730,22 +732,24 @@ func (ns *NodeStateMachine) offlineCallbacks(start bool) { for _, sub := range ns.stateSubs { offState := offlineState & sub.mask onState := cb.state & sub.mask - if offState != onState { - if start { - sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}, caller) - } else { - sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}, caller) - } + if offState == onState { + continue + } + if start { + sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}, caller) + } else { + sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}, caller) } } for i, f := range cb.fields { - if f != nil && ns.fields[i].subs != nil { - for _, fsub := range ns.fields[i].subs { - if start { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f, caller) - } else { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil, caller) - } + if f == nil || ns.fields[i].subs == nil { + continue + } + for _, fsub := range ns.fields[i].subs { + if start { + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f, caller) + } else { + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil, caller) } } } @@ -870,10 +874,8 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface state := node.state callback := func() { - if len(f.subs) > 0 { - for _, cb := range f.subs { - cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value, caller) - } + for _, cb := range f.subs { + cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value, caller) } } callNow := caller == nil From 6823173b4adcf467e2635f2d02daf4173d04d8e8 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Fri, 21 Aug 2020 13:28:06 +0200 Subject: [PATCH 05/13] p2p/nodestate: removed temporary debug prints --- p2p/nodestate/nodestate.go | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index 44a3778a8004..bee26e807676 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -18,7 +18,6 @@ package nodestate import ( "errors" - "fmt" "reflect" "sync" "time" @@ -33,8 +32,6 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) -const debugPrints = false - type ( // NodeStateMachine connects different system components operating on subsets of // network nodes. Node states are represented by 64 bit vectors with each bit assigned @@ -633,9 +630,6 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, } oldState := node.state newState := (node.state & (^reset)) | set - if debugPrints { - fmt.Println("SetState", n.ID(), "old", Flags{oldState, ns.setup}, "new", Flags{newState, ns.setup}, "set", setFlags, "reset", resetFlags, "timeout", timeout) - } changed := oldState ^ newState node.state = newState @@ -676,9 +670,6 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, if v == nil { continue } - if debugPrints { - fmt.Println("discardField", n.ID(), ns.setup.fields[i].name, v) - } f := ns.fields[i] if len(f.subs) > 0 { for _, cb := range f.subs { @@ -788,9 +779,6 @@ func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time ns.removeTimeouts(node, mask) t := &nodeStateTimeout{mask: mask} t.timer = ns.clock.AfterFunc(timeout, func() { - if debugPrints { - fmt.Println("timeout", n.ID(), Flags{mask, ns.setup}) - } ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0, nil) }) node.timeouts = append(node.timeouts, t) @@ -860,9 +848,6 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface return errors.New("invalid field type") } oldValue := node.fields[fieldIndex] - if debugPrints { - fmt.Println("SetField", n.ID(), Flags{node.state, ns.setup}, field.setup.fields[field.index].name, oldValue, value) - } if value == oldValue { ns.lock.Unlock() return nil From ec1c5287a3aaedc848c9613e1ef91e11af162e28 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Sat, 22 Aug 2020 14:34:13 +0200 Subject: [PATCH 06/13] p2p/nodestate: simplified callbacks --- p2p/nodestate/nodestate.go | 149 ++++++++++++++++++-------------- p2p/nodestate/nodestate_test.go | 88 +++++++++---------- 2 files changed, 126 insertions(+), 111 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index bee26e807676..bc89a34039da 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -60,8 +60,9 @@ type ( dbNodeKey []byte nodes map[enode.ID]*nodeInfo offlineCallbackList []offlineCallback - callbackCount int - callbackWait *sync.Cond + opFlag bool // an operation has started + opWait *sync.Cond // signaled when the operation ends + pending []func() // pending callback list of the current operation // Registered state flags or fields. Modifications are allowed // only when the node state machine has not been started. @@ -118,19 +119,15 @@ type ( // of node flags with each bit assigned to a flag index (LSB represents flag 0). bitMask uint64 - Caller struct { - pending []func() - } - // StateCallback is a subscription callback which is called when one of the // state flags that is included in the subscription state mask is changed. // Note: oldState and newState are also masked with the subscription mask so only // the relevant bits are included. - StateCallback func(n *enode.Node, oldState, newState Flags, caller *Caller) + StateCallback func(n *enode.Node, oldState, newState Flags) // FieldCallback is a subscription callback which is called when the value of // a specific field is changed. - FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{}, caller *Caller) + FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{}) // nodeInfo contains node state, fields and state timeouts nodeInfo struct { @@ -325,7 +322,7 @@ func NewNodeStateMachine(db ethdb.KeyValueStore, dbKey []byte, clock mclock.Cloc nodes: make(map[enode.ID]*nodeInfo), fields: make([]*fieldInfo, len(setup.fields)), } - ns.callbackWait = sync.NewCond(&ns.lock) + ns.opWait = sync.NewCond(&ns.lock) stateNameMap := make(map[string]int) for index, flag := range setup.flags { if _, ok := stateNameMap[flag.name]; ok { @@ -422,8 +419,8 @@ func (ns *NodeStateMachine) Start() { // Stop stops the state machine and saves its state if a database was supplied func (ns *NodeStateMachine) Stop() { ns.lock.Lock() - if ns.callbackCount != 0 { - ns.callbackWait.Wait() + if ns.opFlag { + ns.opWait.Wait() } for _, node := range ns.nodes { fields := make([]interface{}, len(node.fields)) @@ -610,11 +607,26 @@ func (ns *NodeStateMachine) Persist(n *enode.Node) error { // It only returns after all subsequent immediate changes (including those changed by the // callbacks) have been processed. If a flag with a timeout is set again, the operation // removes or replaces the existing timeout. -func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration, caller *Caller) { +func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { ns.lock.Lock() + defer ns.lock.Unlock() + + ns.opStart() + ns.setState(n, setFlags, resetFlags, timeout) + ns.opEnd() +} + +func (ns *NodeStateMachine) SetStateSub(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.opCheck() + ns.setState(n, setFlags, resetFlags, timeout) +} + +func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { ns.checkStarted() if ns.stopped { - ns.lock.Unlock() return } @@ -622,7 +634,6 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, id, node := ns.updateEnode(n) if node == nil { if set == 0 { - ns.lock.Unlock() return } node = ns.newNode(n) @@ -643,7 +654,6 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, ns.addTimeout(n, set, timeout) } if newState == oldState { - ns.lock.Unlock() return } if newState == 0 { @@ -659,7 +669,7 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, callback := func() { for _, sub := range ns.stateSubs { if changed&sub.mask != 0 { - sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}, caller) + sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) } } if newState != 0 { @@ -673,50 +683,52 @@ func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, f := ns.fields[i] if len(f.subs) > 0 { for _, cb := range f.subs { - cb(n, Flags{setup: ns.setup}, v, nil, caller) + cb(n, Flags{setup: ns.setup}, v, nil) } } } } - callNow := caller == nil - if callNow { - caller = &Caller{} - } else { - caller.pending = append(caller.pending, callback) - ns.callbackCount++ + ns.pending = append(ns.pending, callback) +} + +func (ns *NodeStateMachine) opCheck() { + if !ns.opFlag { + panic("Operation has not started") } - ns.lock.Unlock() - // call state update subscription callbacks without holding the mutex - if callNow { - callback() - ns.processCallbacks(caller) +} + +func (ns *NodeStateMachine) opStart() { + if ns.opFlag { + ns.opWait.Wait() } + ns.opFlag = true } // processCallbacks runs pending callbacks of a given node in a guaranteed correct order. // Callbacks resulting from a state/field change performed in a previous callback are always // put at the end of the pending list and therefore processed after all callbacks resulting // from the previous state/field change. -func (ns *NodeStateMachine) processCallbacks(caller *Caller) { - for len(caller.pending) != 0 { - list := caller.pending +func (ns *NodeStateMachine) opEnd() { + for len(ns.pending) != 0 { + list := ns.pending + ns.lock.Unlock() for _, cb := range list { cb() } ns.lock.Lock() - caller.pending = caller.pending[len(list):] - ns.callbackCount -= len(list) - if ns.callbackCount == 0 { - ns.callbackWait.Signal() - } - ns.lock.Unlock() + ns.pending = ns.pending[len(list):] } + ns.pending = nil + ns.opFlag = false + ns.opWait.Signal() } // offlineCallbacks calls state update callbacks at startup or shutdown func (ns *NodeStateMachine) offlineCallbacks(start bool) { ns.lock.Lock() - caller := &Caller{} + defer ns.lock.Unlock() + + ns.opStart() for _, cb := range ns.offlineCallbackList { cb := cb callback := func() { @@ -727,9 +739,9 @@ func (ns *NodeStateMachine) offlineCallbacks(start bool) { continue } if start { - sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}, caller) + sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) } else { - sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}, caller) + sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) } } for i, f := range cb.fields { @@ -738,19 +750,17 @@ func (ns *NodeStateMachine) offlineCallbacks(start bool) { } for _, fsub := range ns.fields[i].subs { if start { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f, caller) + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) } else { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil, caller) + fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) } } } } - caller.pending = append(caller.pending, callback) - ns.callbackCount++ + ns.pending = append(ns.pending, callback) } ns.offlineCallbackList = nil - ns.lock.Unlock() - ns.processCallbacks(caller) + ns.opEnd() } // AddTimeout adds a node state timeout associated to the given state flag(s). @@ -779,7 +789,7 @@ func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time ns.removeTimeouts(node, mask) t := &nodeStateTimeout{mask: mask} t.timer = ns.clock.AfterFunc(timeout, func() { - ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0, nil) + ns.SetState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) }) node.timeouts = append(node.timeouts, t) if mask&ns.saveFlags != 0 { @@ -828,28 +838,41 @@ func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { } // SetField sets the given field of the given node -func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}, caller *Caller) error { +func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error { ns.lock.Lock() + defer ns.lock.Unlock() + + ns.opStart() + err := ns.setField(n, field, value) + ns.opEnd() + return err +} + +func (ns *NodeStateMachine) SetFieldSub(n *enode.Node, field Field, value interface{}) error { + ns.lock.Lock() + defer ns.lock.Unlock() + + ns.opCheck() + return ns.setField(n, field, value) +} + +func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface{}) error { ns.checkStarted() if ns.stopped { - ns.lock.Unlock() return nil } _, node := ns.updateEnode(n) if node == nil { - ns.lock.Unlock() return nil } fieldIndex := ns.fieldIndex(field) f := ns.fields[fieldIndex] if value != nil && reflect.TypeOf(value) != f.ftype { log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype) - ns.lock.Unlock() return errors.New("invalid field type") } oldValue := node.fields[fieldIndex] if value == oldValue { - ns.lock.Unlock() return nil } node.fields[fieldIndex] = value @@ -860,22 +883,10 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface state := node.state callback := func() { for _, cb := range f.subs { - cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value, caller) + cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) } } - callNow := caller == nil - if callNow { - caller = &Caller{} - } else { - caller.pending = append(caller.pending, callback) - ns.callbackCount++ - } - ns.lock.Unlock() - // call state update subscription callbacks without holding the mutex - if callNow { - callback() - ns.processCallbacks(caller) - } + ns.pending = append(ns.pending, callback) return nil } @@ -884,6 +895,7 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) { ns.lock.Lock() ns.checkStarted() + ns.opStart() type callback struct { node *enode.Node state bitMask @@ -899,6 +911,9 @@ func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n for _, c := range callbacks { cb(c.node, Flags{mask: c.state, setup: ns.setup}) } + ns.lock.Lock() + ns.opEnd() + ns.lock.Unlock() } // GetNode returns the enode currently associated with the given ID @@ -917,7 +932,7 @@ func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node { // being in a given set specified by required and disabled state flags func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) { var count int64 - ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags, caller *Caller) { + ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) { oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) if newMatch == oldMatch { diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go index 6139a01d27f7..37fc42359c31 100644 --- a/p2p/nodestate/nodestate_test.go +++ b/p2p/nodestate/nodestate_test.go @@ -70,15 +70,15 @@ func TestCallback(t *testing.T) { set0 := make(chan struct{}, 1) set1 := make(chan struct{}, 1) set2 := make(chan struct{}, 1) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags, caller *Caller) { set0 <- struct{}{} }) - ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags, caller *Caller) { set1 <- struct{}{} }) - ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags, caller *Caller) { set2 <- struct{}{} }) + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} }) + ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} }) + ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} }) ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) - ns.SetState(testNode(1), flags[1], Flags{}, time.Second, nil) - ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second, nil) + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetState(testNode(1), flags[1], Flags{}, time.Second) + ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second) for i := 0; i < 3; i++ { select { @@ -104,11 +104,11 @@ func TestPersistentFlags(t *testing.T) { ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, time.Second, nil) // state with timeout should not be saved - ns.SetState(testNode(2), flags[1], Flags{}, 0, nil) - ns.SetState(testNode(3), flags[2], Flags{}, 0, nil) - ns.SetState(testNode(4), flags[3], Flags{}, 0, nil) - ns.SetState(testNode(5), flags[0], Flags{}, 0, nil) + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved + ns.SetState(testNode(2), flags[1], Flags{}, 0) + ns.SetState(testNode(3), flags[2], Flags{}, 0) + ns.SetState(testNode(4), flags[3], Flags{}, 0) + ns.SetState(testNode(5), flags[0], Flags{}, 0) ns.Persist(testNode(5)) select { case <-saveNode: @@ -145,19 +145,19 @@ func TestSetField(t *testing.T) { ns.Start() // Set field before setting state - ns.SetField(testNode(1), fields[0], "hello world", nil) + ns.SetField(testNode(1), fields[0], "hello world") field := ns.GetField(testNode(1), fields[0]) if field != nil { t.Fatalf("Field shouldn't be set before setting states") } // Set field after setting state - ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) - ns.SetField(testNode(1), fields[0], "hello world", nil) + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], "hello world") field = ns.GetField(testNode(1), fields[0]) if field == nil { t.Fatalf("Field should be set after setting states") } - if err := ns.SetField(testNode(1), fields[0], 123, nil); err == nil { + if err := ns.SetField(testNode(1), fields[0], 123); err == nil { t.Fatalf("Invalid field should be rejected") } // Dirty node should be written back @@ -177,10 +177,10 @@ func TestUnsetField(t *testing.T) { ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, time.Second, nil) - ns.SetField(testNode(1), fields[0], "hello world", nil) + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) + ns.SetField(testNode(1), fields[0], "hello world") - ns.SetState(testNode(1), Flags{}, flags[0], 0, nil) + ns.SetState(testNode(1), Flags{}, flags[0], 0) if field := ns.GetField(testNode(1), fields[0]); field != nil { t.Fatalf("Field should be unset") } @@ -194,7 +194,7 @@ func TestSetState(t *testing.T) { type change struct{ old, new Flags } set := make(chan change, 1) - ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags, caller *Caller) { + ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) { set <- change{ old: oldState, new: newState, @@ -224,25 +224,25 @@ func TestSetState(t *testing.T) { return } } - ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) + ns.SetState(testNode(1), flags[0], Flags{}, 0) check(Flags{}, flags[0], true) - ns.SetState(testNode(1), flags[1], Flags{}, 0, nil) + ns.SetState(testNode(1), flags[1], Flags{}, 0) check(flags[0], flags[0].Or(flags[1]), true) - ns.SetState(testNode(1), flags[2], Flags{}, 0, nil) + ns.SetState(testNode(1), flags[2], Flags{}, 0) check(Flags{}, Flags{}, false) - ns.SetState(testNode(1), Flags{}, flags[0], 0, nil) + ns.SetState(testNode(1), Flags{}, flags[0], 0) check(flags[0].Or(flags[1]), flags[1], true) - ns.SetState(testNode(1), Flags{}, flags[1], 0, nil) + ns.SetState(testNode(1), Flags{}, flags[1], 0) check(flags[1], Flags{}, true) - ns.SetState(testNode(1), Flags{}, flags[2], 0, nil) + ns.SetState(testNode(1), Flags{}, flags[2], 0) check(Flags{}, Flags{}, false) - ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second, nil) + ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second) check(Flags{}, flags[0].Or(flags[1]), true) clock.Run(time.Second) check(flags[0].Or(flags[1]), Flags{}, true) @@ -282,9 +282,9 @@ func TestPersistentFields(t *testing.T) { ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) - ns.SetField(testNode(1), fields[0], uint64(100), nil) - ns.SetField(testNode(1), fields[1], "hello world", nil) + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], uint64(100)) + ns.SetField(testNode(1), fields[1], "hello world") ns.Stop() ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) @@ -317,7 +317,7 @@ func TestFieldSub(t *testing.T) { lastState Flags lastOldValue, lastNewValue interface{} ) - ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}, caller *Caller) { + ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { lastState, lastOldValue, lastNewValue = state, oldValue, newValue }) check := func(state Flags, oldValue, newValue interface{}) { @@ -326,19 +326,19 @@ func TestFieldSub(t *testing.T) { } } ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) - ns.SetField(testNode(1), fields[0], uint64(100), nil) + ns.SetState(testNode(1), flags[0], Flags{}, 0) + ns.SetField(testNode(1), fields[0], uint64(100)) check(flags[0], nil, uint64(100)) ns.Stop() check(s.OfflineFlag(), uint64(100), nil) ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}, caller *Caller) { + ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { lastState, lastOldValue, lastNewValue = state, oldValue, newValue }) ns2.Start() check(s.OfflineFlag(), nil, uint64(100)) - ns2.SetState(testNode(1), Flags{}, flags[0], 0, nil) + ns2.SetState(testNode(1), Flags{}, flags[0], 0) check(Flags{}, uint64(100), nil) ns2.Stop() } @@ -351,7 +351,7 @@ func TestDuplicatedFlags(t *testing.T) { type change struct{ old, new Flags } set := make(chan change, 1) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags, caller *Caller) { + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set <- change{oldState, newState} }) @@ -379,9 +379,9 @@ func TestDuplicatedFlags(t *testing.T) { return } } - ns.SetState(testNode(1), flags[0], Flags{}, time.Second, nil) + ns.SetState(testNode(1), flags[0], Flags{}, time.Second) check(Flags{}, flags[0], true) - ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second, nil) // extend the timeout to 2s + ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s check(Flags{}, flags[0], false) clock.Run(2 * time.Second) @@ -394,19 +394,19 @@ func TestCallbackOrder(t *testing.T) { s, flags, _ := testSetup([]bool{false, false, false, false}, nil) ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags, caller *Caller) { + ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { if newState.Equals(flags[0]) { - ns.SetState(n, flags[1], Flags{}, 0, caller) - ns.SetState(n, flags[2], Flags{}, 0, caller) + ns.SetStateSub(n, flags[1], Flags{}, 0) + ns.SetStateSub(n, flags[2], Flags{}, 0) } }) - ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags, caller *Caller) { + ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { if newState.Equals(flags[1]) { - ns.SetState(n, flags[3], Flags{}, 0, caller) + ns.SetStateSub(n, flags[3], Flags{}, 0) } }) lastState := Flags{} - ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags, caller *Caller) { + ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) { if !oldState.Equals(lastState) { t.Fatalf("Wrong callback order") } @@ -416,5 +416,5 @@ func TestCallbackOrder(t *testing.T) { ns.Start() defer ns.Stop() - ns.SetState(testNode(1), flags[0], Flags{}, 0, nil) + ns.SetState(testNode(1), flags[0], Flags{}, 0) } From 9d8f9938592511ecc956deaed2ace6fbbd7428e4 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Mon, 24 Aug 2020 14:33:22 +0200 Subject: [PATCH 07/13] les: server pool working with new NodeStateMachine --- les/lespay/client/fillset_test.go | 2 +- les/serverpool.go | 52 ++++++++++++++++++------------- p2p/nodestate/nodestate.go | 12 ++++++- 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/les/lespay/client/fillset_test.go b/les/lespay/client/fillset_test.go index 58240682c60d..3a252e7539c3 100644 --- a/les/lespay/client/fillset_test.go +++ b/les/lespay/client/fillset_test.go @@ -102,7 +102,7 @@ func TestFillSet(t *testing.T) { expNotWaiting() // remove all previosly set flags ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { - ns.SetState(node, nodestate.Flags{}, sfTest1, 0) + ns.SetStateSub(node, nodestate.Flags{}, sfTest1, 0) }) // now expect FillSet to fill the set up again with 10 new nodes expWaiting(10, true) diff --git a/les/serverpool.go b/les/serverpool.go index aff774324132..c3ac033ac2ea 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -166,7 +166,7 @@ func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *lpc.ValueTracker, d if oldState.Equals(sfWaitDialTimeout) && newState.IsEmpty() { // dial timeout, no connection s.setRedialWait(n, dialCost, dialWaitStep) - s.ns.SetState(n, nodestate.Flags{}, sfDialing, 0) + s.ns.SetStateSub(n, nodestate.Flags{}, sfDialing, 0) } }) @@ -193,10 +193,10 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod if rand.Intn(maxQueryFails*2) < int(fails) { // skip pre-negotiation with increasing chance, max 50% // this ensures that the client can operate even if UDP is not working at all - s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10) + s.ns.SetStateSub(n, sfCanDial, nodestate.Flags{}, time.Second*10) // set canDial before resetting queried so that FillSet will not read more // candidates unnecessarily - s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0) + s.ns.SetStateSub(n, nodestate.Flags{}, sfQueried, 0) return } go func() { @@ -206,12 +206,15 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod } else { atomic.StoreUint32(&s.queryFails, 0) } - if q == 1 { - s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10) - } else { - s.setRedialWait(n, queryCost, queryWaitStep) - } - s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0) + s.ns.Operation(func() { + // we are no longer running in the operation that the callback belongs to, start a new one because of setRedialWait + if q == 1 { + s.ns.SetStateSub(n, sfCanDial, nodestate.Flags{}, time.Second*10) + } else { + s.setRedialWait(n, queryCost, queryWaitStep) + } + s.ns.SetStateSub(n, nodestate.Flags{}, sfQueried, 0) + }) }() } }) @@ -250,7 +253,7 @@ func (s *serverPool) start() { // waiting time then the system clock was probably adjusted wait = lastWait } - s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second) + s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second) } }) } @@ -279,9 +282,11 @@ func (s *serverPool) registerPeer(p *serverPeer) { // unregisterPeer implements serverPeerSubscriber func (s *serverPool) unregisterPeer(p *serverPeer) { - s.setRedialWait(p.Node(), dialCost, dialWaitStep) - s.ns.SetState(p.Node(), nodestate.Flags{}, sfConnected, 0) - s.ns.SetField(p.Node(), sfiConnectedStats, nil) + s.ns.Operation(func() { + s.setRedialWait(p.Node(), dialCost, dialWaitStep) + s.ns.SetStateSub(p.Node(), nodestate.Flags{}, sfConnected, 0) + s.ns.SetFieldSub(p.Node(), sfiConnectedStats, nil) + }) s.vt.Unregister(p.ID()) p.setValueTracker(nil, nil) } @@ -380,14 +385,15 @@ func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue fl // updateWeight calculates the node weight and updates the nodeWeight field and the // hasValue flag. It also saves the node state if necessary. +// Note: this function should run inside a NodeStateMachine operation func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) { weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost)) if weight >= nodeWeightThreshold { - s.ns.SetState(node, sfHasValue, nodestate.Flags{}, 0) - s.ns.SetField(node, sfiNodeWeight, weight) + s.ns.SetStateSub(node, sfHasValue, nodestate.Flags{}, 0) + s.ns.SetFieldSub(node, sfiNodeWeight, weight) } else { - s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0) - s.ns.SetField(node, sfiNodeWeight, nil) + s.ns.SetStateSub(node, nodestate.Flags{}, sfHasValue, 0) + s.ns.SetFieldSub(node, sfiNodeWeight, nil) } s.ns.Persist(node) // saved if node history or hasValue changed } @@ -400,6 +406,7 @@ func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDia // a significant amount of service value again its waiting time is quickly reduced or reset // to the minimum. // Note: node weight is also recalculated and updated by this function. +// Note 2: this function should run inside a NodeStateMachine operation func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) { n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory) sessionValue, totalValue := s.serviceValue(node) @@ -450,21 +457,22 @@ func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep if wait < waitThreshold { n.redialWaitStart = unixTime n.redialWaitEnd = unixTime + int64(nextTimeout) - s.ns.SetField(node, sfiNodeHistory, n) - s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, wait) + s.ns.SetFieldSub(node, sfiNodeHistory, n) + s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, wait) s.updateWeight(node, totalValue, totalDialCost) } else { // discard known node statistics if waiting time is very long because the node // hasn't been responsive for a very long time - s.ns.SetField(node, sfiNodeHistory, nil) - s.ns.SetField(node, sfiNodeWeight, nil) - s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0) + s.ns.SetFieldSub(node, sfiNodeHistory, nil) + s.ns.SetFieldSub(node, sfiNodeWeight, nil) + s.ns.SetStateSub(node, nodestate.Flags{}, sfHasValue, 0) } } // calculateWeight calculates and sets the node weight without altering the node history. // This function should be called during startup and shutdown only, otherwise setRedialWait // will keep the weights updated as the underlying statistics are adjusted. +// Note: this function should run inside a NodeStateMachine operation func (s *serverPool) calculateWeight(node *enode.Node) { n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory) _, totalValue := s.serviceValue(node) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index bc89a34039da..19bd3a9a4b53 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -698,7 +698,7 @@ func (ns *NodeStateMachine) opCheck() { } func (ns *NodeStateMachine) opStart() { - if ns.opFlag { + for ns.opFlag { ns.opWait.Wait() } ns.opFlag = true @@ -723,6 +723,16 @@ func (ns *NodeStateMachine) opEnd() { ns.opWait.Signal() } +func (ns *NodeStateMachine) Operation(fn func()) { + ns.lock.Lock() + ns.opStart() + ns.lock.Unlock() + fn() + ns.lock.Lock() + ns.opEnd() + ns.lock.Unlock() +} + // offlineCallbacks calls state update callbacks at startup or shutdown func (ns *NodeStateMachine) offlineCallbacks(start bool) { ns.lock.Lock() From b5f788bfc36712280bcef00a6e12b24ecaa70400 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Mon, 24 Aug 2020 15:48:33 +0200 Subject: [PATCH 08/13] p2p/nodestate: added/updated comments --- p2p/nodestate/nodestate.go | 103 +++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 44 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index 19bd3a9a4b53..d3d4648e8582 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -33,25 +33,29 @@ import ( ) type ( - // NodeStateMachine connects different system components operating on subsets of - // network nodes. Node states are represented by 64 bit vectors with each bit assigned - // to a state flag. Each state flag has a descriptor structure and the mapping is - // created automatically. It is possible to subscribe to subsets of state flags and - // receive a callback if one of the nodes has a relevant state flag changed. - // Callbacks can also modify further flags of the same node or other nodes. State - // updates only return after all immediate effects throughout the system have happened - // (deadlocks should be avoided by design of the implemented state logic). The caller - // can also add timeouts assigned to a certain node and a subset of state flags. - // If the timeout elapses, the flags are reset. If all relevant flags are reset then - // the timer is dropped. State flags with no timeout are persisted in the database - // if the flag descriptor enables saving. If a node has no state flags set at any - // moment then it is discarded. - // - // Extra node fields can also be registered so system components can also store more - // complex state for each node that is relevant to them, without creating a custom - // peer set. Fields can be shared across multiple components if they all know the - // field ID. Subscription to fields is also possible. Persistent fields should have - // an encoder and a decoder function. + // NodeStateMachine implements a network node-related event subscription system. + // It can assign binary state flags and fields of arbitrary type to each node and allows + // subscriptions to flag/field changes which can also modify further flags and fields, + // potentially triggering further subscriptions. An operation includes an initial change + // and all resulting subsequent changes and always ends in a consistent global state. + // It is initiated by a "top level" SetState/SetField call that blocks (also blocking other + // top-level functions) until the operation is finished. Callbacks making further changes + // should use the non-blocking SetStateSub/SetFieldSub functions. The tree of events + // resulting from the initial changes is traversed in a breadth-first order, ensuring for + // each subscription callback that all other callbacks caused by the same change triggering + // the current callback are processed before anything is triggered by the changes made in the + // current callback. In practice this logic ensures that all subscriptions "see" events in + // the logical order, callbacks are never called concurrently and "back and forth" effects + // are also possible. The state machine design should ensure that infinite event cycles + // cannot happen. + // The caller can also add timeouts assigned to a certain node and a subset of state flags. + // If the timeout elapses, the flags are reset. If all relevant flags are reset then the timer + // is dropped. State flags with no timeout are persisted in the database if the flag + // descriptor enables saving. If a node has no state flags set at any moment then it is discarded. + // Note: in order to avoid mutex deadlocks the callbacks should never lock a mutex that + // might be locked when the top level SetState/SetField functions are called. If a function + // potentially performs state/field changes then it is recommended to mention this fact in the + // function description, along with whether it should run inside an operation callback. NodeStateMachine struct { started, stopped bool lock sync.Mutex @@ -62,7 +66,7 @@ type ( offlineCallbackList []offlineCallback opFlag bool // an operation has started opWait *sync.Cond // signaled when the operation ends - pending []func() // pending callback list of the current operation + opPending []func() // pending callback list of the current operation // Registered state flags or fields. Modifications are allowed // only when the node state machine has not been started. @@ -361,10 +365,12 @@ func (ns *NodeStateMachine) fieldIndex(field Field) int { } // SubscribeState adds a node state subscription. The callback is called while the state -// machine mutex is not held and it is allowed to make further state updates. All immediate -// changes throughout the system are processed in the same thread/goroutine. It is the -// responsibility of the implemented state logic to avoid deadlocks caused by the callbacks, -// infinite toggling of flags or hazardous/non-deterministic state changes. +// machine mutex is not held and it is allowed to make further state updates using the +// non-blocking SetStateSub/SetFieldSub functions. All callbacks of an operation are running +// from the thread/goroutine of the initial caller and parallel operations are not permitted. +// Therefore the callback is never called concurrently. It is the responsibility of the +// implemented state logic to avoid deadlocks and to reach a stable state in a finite amount +// of steps. // State subscriptions should be installed before loading the node database or making the // first state update. func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) { @@ -603,19 +609,19 @@ func (ns *NodeStateMachine) Persist(n *enode.Node) error { return nil } -// SetState updates the given node state flags and processes all resulting callbacks. -// It only returns after all subsequent immediate changes (including those changed by the -// callbacks) have been processed. If a flag with a timeout is set again, the operation -// removes or replaces the existing timeout. +// SetState updates the given node state flags and blocks until the operation is finished. +// If a flag with a timeout is set again, the operation removes or replaces the existing timeout. func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { ns.lock.Lock() defer ns.lock.Unlock() ns.opStart() ns.setState(n, setFlags, resetFlags, timeout) - ns.opEnd() + ns.opFinish() } +// SetStateSub updates the given node state flags without blocking (should be called +// from a subscription/operation callback). func (ns *NodeStateMachine) SetStateSub(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { ns.lock.Lock() defer ns.lock.Unlock() @@ -688,15 +694,17 @@ func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, } } } - ns.pending = append(ns.pending, callback) + ns.opPending = append(ns.opPending, callback) } +// opCheck checks whether an operation is active func (ns *NodeStateMachine) opCheck() { if !ns.opFlag { panic("Operation has not started") } } +// opStart waits until other operations are finished and starts a new one func (ns *NodeStateMachine) opStart() { for ns.opFlag { ns.opWait.Wait() @@ -704,32 +712,35 @@ func (ns *NodeStateMachine) opStart() { ns.opFlag = true } -// processCallbacks runs pending callbacks of a given node in a guaranteed correct order. +// opFinish finishes the current operation by running all pending callbacks. // Callbacks resulting from a state/field change performed in a previous callback are always // put at the end of the pending list and therefore processed after all callbacks resulting // from the previous state/field change. -func (ns *NodeStateMachine) opEnd() { - for len(ns.pending) != 0 { - list := ns.pending +func (ns *NodeStateMachine) opFinish() { + for len(ns.opPending) != 0 { + list := ns.opPending ns.lock.Unlock() for _, cb := range list { cb() } ns.lock.Lock() - ns.pending = ns.pending[len(list):] + ns.opPending = ns.opPending[len(list):] } - ns.pending = nil + ns.opPending = nil ns.opFlag = false ns.opWait.Signal() } +// Operation calls the given function as an operation callback. This allows the caller +// to start an operation with multiple initial changes. The same rules apply as for +// subscription callbacks. func (ns *NodeStateMachine) Operation(fn func()) { ns.lock.Lock() ns.opStart() ns.lock.Unlock() fn() ns.lock.Lock() - ns.opEnd() + ns.opFinish() ns.lock.Unlock() } @@ -767,10 +778,10 @@ func (ns *NodeStateMachine) offlineCallbacks(start bool) { } } } - ns.pending = append(ns.pending, callback) + ns.opPending = append(ns.opPending, callback) } ns.offlineCallbackList = nil - ns.opEnd() + ns.opFinish() } // AddTimeout adds a node state timeout associated to the given state flag(s). @@ -832,7 +843,9 @@ func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) { } } -// GetField retrieves the given field of the given node +// GetField retrieves the given field of the given node. Note that when used in a +// subscription callback the result can be out of sync with the state change represented +// by the callback parameters so extra safety checks might be necessary. func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { ns.lock.Lock() defer ns.lock.Unlock() @@ -847,17 +860,19 @@ func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { return nil } -// SetField sets the given field of the given node +// SetField sets the given field of the given node and blocks until the operation is finished func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error { ns.lock.Lock() defer ns.lock.Unlock() ns.opStart() err := ns.setField(n, field, value) - ns.opEnd() + ns.opFinish() return err } +// SetFieldSub sets the given field of the given node without blocking (should be called +// from a subscription/operation callback). func (ns *NodeStateMachine) SetFieldSub(n *enode.Node, field Field, value interface{}) error { ns.lock.Lock() defer ns.lock.Unlock() @@ -896,7 +911,7 @@ func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) } } - ns.pending = append(ns.pending, callback) + ns.opPending = append(ns.opPending, callback) return nil } @@ -922,7 +937,7 @@ func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n cb(c.node, Flags{mask: c.state, setup: ns.setup}) } ns.lock.Lock() - ns.opEnd() + ns.opFinish() ns.lock.Unlock() } From a632106ef2e21fec10719e73899e2946ad005f7a Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Mon, 24 Aug 2020 15:58:24 +0200 Subject: [PATCH 09/13] p2p/nodestate: reverted ForEach change --- les/lespay/client/fillset_test.go | 2 +- les/serverpool.go | 32 +++++++++++++++++-------------- p2p/nodestate/nodestate.go | 8 +++----- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/les/lespay/client/fillset_test.go b/les/lespay/client/fillset_test.go index 3a252e7539c3..58240682c60d 100644 --- a/les/lespay/client/fillset_test.go +++ b/les/lespay/client/fillset_test.go @@ -102,7 +102,7 @@ func TestFillSet(t *testing.T) { expNotWaiting() // remove all previosly set flags ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { - ns.SetStateSub(node, nodestate.Flags{}, sfTest1, 0) + ns.SetState(node, nodestate.Flags{}, sfTest1, 0) }) // now expect FillSet to fill the set up again with 10 new nodes expWaiting(10, true) diff --git a/les/serverpool.go b/les/serverpool.go index c3ac033ac2ea..4045f9888208 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -243,18 +243,20 @@ func (s *serverPool) start() { } } unixTime := s.unixTime() - s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { - s.calculateWeight(node) - if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime { - wait := n.redialWaitEnd - unixTime - lastWait := n.redialWaitEnd - n.redialWaitStart - if wait > lastWait { - // if the time until expiration is larger than the last suggested - // waiting time then the system clock was probably adjusted - wait = lastWait + s.ns.Operation(func() { + s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { + s.calculateWeight(node) + if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime { + wait := n.redialWaitEnd - unixTime + lastWait := n.redialWaitEnd - n.redialWaitStart + if wait > lastWait { + // if the time until expiration is larger than the last suggested + // waiting time then the system clock was probably adjusted + wait = lastWait + } + s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second) } - s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second) - } + }) }) } @@ -264,9 +266,11 @@ func (s *serverPool) stop() { if s.fillSet != nil { s.fillSet.Close() } - s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) { - // recalculate weight of connected nodes in order to update hasValue flag if necessary - s.calculateWeight(n) + s.ns.Operation(func() { + s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) { + // recalculate weight of connected nodes in order to update hasValue flag if necessary + s.calculateWeight(n) + }) }) s.ns.Stop() } diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index d3d4648e8582..ab51dead6b9c 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -916,11 +916,12 @@ func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface } // ForEach calls the callback for each node having all of the required and none of the -// disabled flags set +// disabled flags set. +// Note that this callback is not an operation callback but ForEach can be called from an +// Operation callback or Operation can also be called from a ForEach callback if necessary. func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) { ns.lock.Lock() ns.checkStarted() - ns.opStart() type callback struct { node *enode.Node state bitMask @@ -936,9 +937,6 @@ func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n for _, c := range callbacks { cb(c.node, Flags{mask: c.state, setup: ns.setup}) } - ns.lock.Lock() - ns.opFinish() - ns.lock.Unlock() } // GetNode returns the enode currently associated with the given ID From d0ea52f914eb2446647525e9c79a53133e1543ba Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Wed, 2 Sep 2020 15:47:38 +0200 Subject: [PATCH 10/13] p2p/nodestate: keep nodes with zero state and non-nil fields --- p2p/nodestate/nodestate.go | 77 +++++++++++++++++---------------- p2p/nodestate/nodestate_test.go | 25 +++-------- 2 files changed, 46 insertions(+), 56 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index ab51dead6b9c..bf4a05b1d986 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -135,11 +135,12 @@ type ( // nodeInfo contains node state, fields and state timeouts nodeInfo struct { - node *enode.Node - state bitMask - timeouts []*nodeStateTimeout - fields []interface{} - db, dirty bool + node *enode.Node + state bitMask + timeouts []*nodeStateTimeout + fields []interface{} + fieldCount int + db, dirty bool } nodeInfoEnc struct { @@ -490,6 +491,7 @@ func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) { if decode := ns.fields[i].decode; decode != nil { if field, err := decode(encField); err == nil { node.fields[i] = field + node.fieldCount++ } else { log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err) return @@ -518,15 +520,6 @@ func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error { for _, t := range node.timeouts { storedState &= ^t.mask } - if storedState == 0 { - if node.db { - node.db = false - ns.deleteNode(id) - } - node.dirty = false - return nil - } - enc := nodeInfoEnc{ Enr: *node.node.Record(), Version: ns.setup.Version, @@ -550,6 +543,14 @@ func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error { enc.Fields[i] = blob lastIndex = i } + if storedState == 0 && lastIndex == -1 { + if node.db { + node.db = false + ns.deleteNode(id) + } + node.dirty = false + return nil + } enc.Fields = enc.Fields[:lastIndex+1] data, err := rlp.EncodeToBytes(&enc) if err != nil { @@ -654,15 +655,14 @@ func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, // even they are not existent(it's noop). ns.removeTimeouts(node, set|reset) - // Register the timeout callback if the new state is not empty - // and timeout itself is required. - if timeout != 0 && newState != 0 { + // Register the timeout callback if required + if timeout != 0 && set != 0 { ns.addTimeout(n, set, timeout) } if newState == oldState { return } - if newState == 0 { + if newState == 0 && node.fieldCount == 0 { delete(ns.nodes, id) if node.db { ns.deleteNode(id) @@ -678,21 +678,6 @@ func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) } } - if newState != 0 { - return - } - // call field subscriptions for discarded fields - for i, v := range node.fields { - if v == nil { - continue - } - f := ns.fields[i] - if len(f.subs) > 0 { - for _, cb := range f.subs { - cb(n, Flags{setup: ns.setup}, v, nil) - } - } - } } ns.opPending = append(ns.opPending, callback) } @@ -886,9 +871,13 @@ func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface if ns.stopped { return nil } - _, node := ns.updateEnode(n) + id, node := ns.updateEnode(n) if node == nil { - return nil + if value == nil { + return nil + } + node = ns.newNode(n) + ns.nodes[id] = node } fieldIndex := ns.fieldIndex(field) f := ns.fields[fieldIndex] @@ -900,11 +889,23 @@ func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface if value == oldValue { return nil } + if oldValue != nil { + node.fieldCount-- + } + if value != nil { + node.fieldCount++ + } node.fields[fieldIndex] = value - if f.encode != nil { - node.dirty = true + if node.state == 0 && node.fieldCount == 0 { + delete(ns.nodes, id) + if node.db { + ns.deleteNode(id) + } + } else { + if f.encode != nil { + node.dirty = true + } } - state := node.state callback := func() { for _, cb := range f.subs { diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go index 37fc42359c31..5f99a3da74cb 100644 --- a/p2p/nodestate/nodestate_test.go +++ b/p2p/nodestate/nodestate_test.go @@ -147,8 +147,13 @@ func TestSetField(t *testing.T) { // Set field before setting state ns.SetField(testNode(1), fields[0], "hello world") field := ns.GetField(testNode(1), fields[0]) + if field == nil { + t.Fatalf("Field should be set before setting states") + } + ns.SetField(testNode(1), fields[0], nil) + field = ns.GetField(testNode(1), fields[0]) if field != nil { - t.Fatalf("Field shouldn't be set before setting states") + t.Fatalf("Field should be unset") } // Set field after setting state ns.SetState(testNode(1), flags[0], Flags{}, 0) @@ -169,23 +174,6 @@ func TestSetField(t *testing.T) { } } -func TestUnsetField(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")}) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - ns.Start() - - ns.SetState(testNode(1), flags[0], Flags{}, time.Second) - ns.SetField(testNode(1), fields[0], "hello world") - - ns.SetState(testNode(1), Flags{}, flags[0], 0) - if field := ns.GetField(testNode(1), fields[0]); field != nil { - t.Fatalf("Field should be unset") - } -} - func TestSetState(t *testing.T) { mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} @@ -339,6 +327,7 @@ func TestFieldSub(t *testing.T) { ns2.Start() check(s.OfflineFlag(), nil, uint64(100)) ns2.SetState(testNode(1), Flags{}, flags[0], 0) + ns2.SetField(testNode(1), fields[0], nil) check(Flags{}, uint64(100), nil) ns2.Stop() } From 585b28b395da77db042c6094c5719f3917183a1b Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Wed, 2 Sep 2020 15:53:07 +0200 Subject: [PATCH 11/13] les: also remove node history field when removing hasValue flag --- les/serverpool.go | 1 + 1 file changed, 1 insertion(+) diff --git a/les/serverpool.go b/les/serverpool.go index 4045f9888208..9bfa0bd7259d 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -398,6 +398,7 @@ func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDia } else { s.ns.SetStateSub(node, nodestate.Flags{}, sfHasValue, 0) s.ns.SetFieldSub(node, sfiNodeWeight, nil) + s.ns.SetFieldSub(node, sfiNodeHistory, nil) } s.ns.Persist(node) // saved if node history or hasValue changed } From 18dc3273db5faca9d236412768cf8433f57535b7 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Wed, 9 Sep 2020 13:08:20 +0200 Subject: [PATCH 12/13] p2p/nodestate: ensure correct shutdown --- p2p/nodestate/nodestate.go | 74 ++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index bf4a05b1d986..7c974b3df628 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -32,6 +32,11 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) +var ( + ErrInvalidField = errors.New("invalid field type") + ErrClosed = errors.New("already closed") +) + type ( // NodeStateMachine implements a network node-related event subscription system. // It can assign binary state flags and fields of arbitrary type to each node and allows @@ -57,7 +62,7 @@ type ( // potentially performs state/field changes then it is recommended to mention this fact in the // function description, along with whether it should run inside an operation callback. NodeStateMachine struct { - started, stopped bool + started, closed bool lock sync.Mutex clock mclock.Clock db ethdb.KeyValueStore @@ -419,29 +424,33 @@ func (ns *NodeStateMachine) Start() { if ns.db != nil { ns.loadFromDb() } - ns.lock.Unlock() + + ns.opStart() ns.offlineCallbacks(true) + ns.opFinish() + ns.lock.Unlock() } // Stop stops the state machine and saves its state if a database was supplied func (ns *NodeStateMachine) Stop() { ns.lock.Lock() - if ns.opFlag { - ns.opWait.Wait() + defer ns.lock.Unlock() + + ns.checkStarted() + if !ns.opStart() { + panic("already closed") } for _, node := range ns.nodes { fields := make([]interface{}, len(node.fields)) copy(fields, node.fields) ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node, node.state, fields}) } - ns.stopped = true if ns.db != nil { ns.saveToDb() - ns.lock.Unlock() - } else { - ns.lock.Unlock() } ns.offlineCallbacks(false) + ns.closed = true + ns.opFinish() } // loadFromDb loads persisted node states from the database @@ -612,13 +621,16 @@ func (ns *NodeStateMachine) Persist(n *enode.Node) error { // SetState updates the given node state flags and blocks until the operation is finished. // If a flag with a timeout is set again, the operation removes or replaces the existing timeout. -func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { +func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) error { ns.lock.Lock() defer ns.lock.Unlock() - ns.opStart() + if !ns.opStart() { + return ErrClosed + } ns.setState(n, setFlags, resetFlags, timeout) ns.opFinish() + return nil } // SetStateSub updates the given node state flags without blocking (should be called @@ -633,10 +645,6 @@ func (ns *NodeStateMachine) SetStateSub(n *enode.Node, setFlags, resetFlags Flag func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { ns.checkStarted() - if ns.stopped { - return - } - set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags) id, node := ns.updateEnode(n) if node == nil { @@ -680,6 +688,7 @@ func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, } } ns.opPending = append(ns.opPending, callback) + return } // opCheck checks whether an operation is active @@ -690,11 +699,15 @@ func (ns *NodeStateMachine) opCheck() { } // opStart waits until other operations are finished and starts a new one -func (ns *NodeStateMachine) opStart() { +func (ns *NodeStateMachine) opStart() bool { for ns.opFlag { ns.opWait.Wait() } + if ns.closed { + return false + } ns.opFlag = true + return true } // opFinish finishes the current operation by running all pending callbacks. @@ -719,22 +732,22 @@ func (ns *NodeStateMachine) opFinish() { // Operation calls the given function as an operation callback. This allows the caller // to start an operation with multiple initial changes. The same rules apply as for // subscription callbacks. -func (ns *NodeStateMachine) Operation(fn func()) { +func (ns *NodeStateMachine) Operation(fn func()) error { ns.lock.Lock() - ns.opStart() + started := ns.opStart() ns.lock.Unlock() + if !started { + return ErrClosed + } fn() ns.lock.Lock() ns.opFinish() ns.lock.Unlock() + return nil } // offlineCallbacks calls state update callbacks at startup or shutdown func (ns *NodeStateMachine) offlineCallbacks(start bool) { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.opStart() for _, cb := range ns.offlineCallbackList { cb := cb callback := func() { @@ -766,20 +779,20 @@ func (ns *NodeStateMachine) offlineCallbacks(start bool) { ns.opPending = append(ns.opPending, callback) } ns.offlineCallbackList = nil - ns.opFinish() } // AddTimeout adds a node state timeout associated to the given state flag(s). // After the specified time interval, the relevant states will be reset. -func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) { +func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) error { ns.lock.Lock() defer ns.lock.Unlock() ns.checkStarted() - if ns.stopped { - return + if ns.closed { + return ErrClosed } ns.addTimeout(n, ns.stateMask(flags), timeout) + return nil } // addTimeout adds a node state timeout associated to the given state flag(s). @@ -836,7 +849,7 @@ func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { defer ns.lock.Unlock() ns.checkStarted() - if ns.stopped { + if ns.closed { return nil } if _, node := ns.updateEnode(n); node != nil { @@ -850,7 +863,9 @@ func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface ns.lock.Lock() defer ns.lock.Unlock() - ns.opStart() + if !ns.opStart() { + return ErrClosed + } err := ns.setField(n, field, value) ns.opFinish() return err @@ -868,9 +883,6 @@ func (ns *NodeStateMachine) SetFieldSub(n *enode.Node, field Field, value interf func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface{}) error { ns.checkStarted() - if ns.stopped { - return nil - } id, node := ns.updateEnode(n) if node == nil { if value == nil { @@ -883,7 +895,7 @@ func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface f := ns.fields[fieldIndex] if value != nil && reflect.TypeOf(value) != f.ftype { log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype) - return errors.New("invalid field type") + return ErrInvalidField } oldValue := node.fields[fieldIndex] if value == oldValue { From a83c0ed621ac0e725e2bb1160d60a72d19702530 Mon Sep 17 00:00:00 2001 From: Zsolt Felfoldi Date: Fri, 11 Sep 2020 02:24:01 +0200 Subject: [PATCH 13/13] p2p/nodestate: fix linter error --- p2p/nodestate/nodestate.go | 1 - 1 file changed, 1 deletion(-) diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index 7c974b3df628..ab28b47a159f 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -688,7 +688,6 @@ func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, } } ns.opPending = append(ns.opPending, callback) - return } // opCheck checks whether an operation is active