diff --git a/p2p/dial.go b/p2p/dial.go index 55d11290897..6da7edf04df 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -267,7 +267,7 @@ loop: } case task := <-d.doneCh: - id := task.dest.ID() + id := task.dest().ID() delete(d.dialing, id) d.updateStaticPool(id) d.dialed++ @@ -439,7 +439,7 @@ func (d *dialScheduler) startStaticDials() { // updateStaticPool attempts to move the given static dial back into staticPool. func (d *dialScheduler) updateStaticPool(id enode.ID) { task, ok := d.static[id] - if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil { + if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest()) == nil { d.addToStaticPool(task) } } @@ -466,10 +466,11 @@ func (d *dialScheduler) removeFromStaticPool(idx int) { // startDial runs the given dial task in a separate goroutine. func (d *dialScheduler) startDial(task *dialTask) { - d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags) - hkey := string(task.dest.ID().Bytes()) + node := task.dest() + d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags) + hkey := string(node.ID().Bytes()) d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) - d.dialing[task.dest.ID()] = task + d.dialing[node.ID()] = task go func() { defer debug.LogPanic() task.run(d) @@ -483,13 +484,19 @@ type dialTask struct { flags connFlag // These fields are private to the task and should not be // accessed by dialScheduler while the task is running. - dest *enode.Node + destPtr atomic.Pointer[enode.Node] lastResolved mclock.AbsTime resolveDelay time.Duration } func newDialTask(dest *enode.Node, flags connFlag) *dialTask { - return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1} + t := &dialTask{flags: flags, staticPoolIndex: -1} + t.destPtr.Store(dest) + return t +} + +func (t *dialTask) dest() *enode.Node { + return t.destPtr.Load() } type dialError struct { @@ -501,19 +508,19 @@ func (t *dialTask) run(d *dialScheduler) { return } - err := t.dial(d, t.dest) + err := t.dial(d, t.dest()) if err != nil { // For static nodes, resolve one more time if dialing fails. if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { if t.resolve(d) { - t.dial(d, t.dest) //nolint:errcheck + t.dial(d, t.dest()) //nolint:errcheck } } } } func (t *dialTask) needResolve() bool { - return t.flags&staticDialedConn != 0 && t.dest.IP() == nil + return t.flags&staticDialedConn != 0 && t.dest().IP() == nil } // resolve attempts to find the current endpoint for the destination @@ -532,29 +539,29 @@ func (t *dialTask) resolve(d *dialScheduler) bool { if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay { return false } - resolved := d.resolver.Resolve(t.dest) + resolved := d.resolver.Resolve(t.dest()) t.lastResolved = d.clock.Now() if resolved == nil { t.resolveDelay *= 2 if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - d.log.Warn("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) + d.log.Warn("Resolving node failed", "id", t.dest().ID(), "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay - t.dest = resolved - d.log.Trace("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + t.destPtr.Store(resolved) + d.log.Trace("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()}) return true } // dial performs the actual connection attempt. func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { - fd, err := d.dialer.Dial(d.ctx, t.dest) + fd, err := d.dialer.Dial(d.ctx, dest) if err != nil { cleanErr := cleanupDialErr(err) - d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanErr) + d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanErr) d.mutex.Lock() d.errors[cleanErr.Error()] = d.errors[cleanErr.Error()] + 1 @@ -566,8 +573,9 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { } func (t *dialTask) String() string { - id := t.dest.ID() - return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) + node := t.dest() + id := node.ID() + return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP()) } func cleanupDialErr(err error) error {