From 774139283b5d72cde9ff463c8b7c0d4c08fddde7 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Sat, 2 Jun 2018 22:46:48 -0500 Subject: [PATCH] peer: Rework version negotiation. This modifies the negotiation logic to ensure the callback has the opportunity to see the message before the peer is disconnected and improves the error handling when reading the remote version message. It also has the side effect of ensuring the protocol version is negotiated before sending reject messages with the exception of the first message not being a version message since negotiation is not possible in that case. This is being changed because it is useful for the server to see the message regardless in order to have the opportunity to things such as update the address manager and reject peers that don't have desired services. --- peer/peer.go | 101 +++++++++++++++++++++++++-------------------------- server.go | 6 +++ 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/peer/peer.go b/peer/peer.go index 17cb477645..3375092709 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -1737,27 +1737,42 @@ func (p *Peer) Disconnect() { close(p.quit) } -// handleRemoteVersionMsg is invoked when a version wire message is received -// from the remote peer. It will return an error if the remote peer's version -// is not compatible with ours. -func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { +// readRemoteVersionMsg waits for the next message to arrive from the remote +// peer. If the next message is not a version message or the version is not +// acceptable then return an error. +func (p *Peer) readRemoteVersionMsg() error { + // Read their version message. + remoteMsg, _, err := p.readMessage() + if err != nil { + return err + } + + // Notify and disconnect clients if the first message is not a version + // message. + msg, ok := remoteMsg.(*wire.MsgVersion) + if !ok { + reason := "a version message must precede all others" + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed, + reason) + _ = p.writeMessage(rejectMsg) + return errors.New(reason) + } + // Detect self connections. if !allowSelfConns && sentNonces.Exists(msg.Nonce) { return errors.New("disconnecting peer connected to self") } - // Notify and disconnect clients that have a protocol version that is - // too old. - if msg.ProtocolVersion < int32(wire.InitialProcotolVersion) { - // Send a reject message indicating the protocol version is - // obsolete and wait for the message to be sent before - // disconnecting. - reason := fmt.Sprintf("protocol version must be %d or greater", - wire.InitialProcotolVersion) - rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete, - reason) - return p.writeMessage(rejectMsg) - } + // Negotiate the protocol version and set the services to what the remote + // peer advertised. + p.flagsMtx.Lock() + p.advertisedProtoVer = uint32(msg.ProtocolVersion) + p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) + p.versionKnown = true + p.services = msg.Services + p.flagsMtx.Unlock() + log.Debugf("Negotiated protocol version %d for peer %s", + p.protocolVersion, p) // Updating a bunch of stats. p.statsMtx.Lock() @@ -1768,51 +1783,31 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix() p.statsMtx.Unlock() - // Negotiate the protocol version. + // Set the peer's ID and user agent. p.flagsMtx.Lock() - p.advertisedProtoVer = uint32(msg.ProtocolVersion) - p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) - p.versionKnown = true - log.Debugf("Negotiated protocol version %d for peer %s", - p.protocolVersion, p) - // Set the peer's ID. p.id = atomic.AddInt32(&nodeCount, 1) - // Set the supported services for the peer to what the remote peer - // advertised. - p.services = msg.Services - // Set the remote peer's user agent. p.userAgent = msg.UserAgent p.flagsMtx.Unlock() - return nil -} - -// readRemoteVersionMsg waits for the next message to arrive from the remote -// peer. If the next message is not a version message or the version is not -// acceptable then return an error. -func (p *Peer) readRemoteVersionMsg() error { - // Read their version message. - msg, _, err := p.readMessage() - if err != nil { - return err - } - - remoteVerMsg, ok := msg.(*wire.MsgVersion) - if !ok { - errStr := "A version message must precede all others" - log.Errorf(errStr) - rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed, - errStr) - return p.writeMessage(rejectMsg) + // Invoke the callback if specified. + if p.cfg.Listeners.OnVersion != nil { + p.cfg.Listeners.OnVersion(p, msg) } - if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil { - return err + // Notify and disconnect clients that have a protocol version that is + // too old. + if msg.ProtocolVersion < int32(wire.InitialProcotolVersion) { + // Send a reject message indicating the protocol version is + // obsolete and wait for the message to be sent before + // disconnecting. + reason := fmt.Sprintf("protocol version must be %d or greater", + wire.InitialProcotolVersion) + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete, + reason) + _ = p.writeMessage(rejectMsg) + return errors.New(reason) } - if p.cfg.Listeners.OnVersion != nil { - p.cfg.Listeners.OnVersion(p, remoteVerMsg) - } return nil } @@ -1953,9 +1948,11 @@ func (p *Peer) start() error { select { case err := <-negotiateErr: if err != nil { + p.Disconnect() return err } case <-time.After(negotiateTimeout): + p.Disconnect() return errors.New("protocol negotiation timeout") } log.Debugf("Connected to %s", p.Addr()) diff --git a/server.go b/server.go index e24c0d4330..c99c430fbf 100644 --- a/server.go +++ b/server.go @@ -327,6 +327,12 @@ func (sp *serverPeer) addBanScore(persistent, transient uint32, reason string) { // to negotiate the protocol version details as well as kick start the // communications. func (sp *serverPeer) OnVersion(p *peer.Peer, msg *wire.MsgVersion) { + // Ignore peers that have a protcol version that is too old. The peer + // negotiation logic will disconnect it after this callback returns. + if msg.ProtocolVersion < int32(wire.InitialProcotolVersion) { + return + } + // Add the remote peer time as a sample for creating an offset against // the local clock to keep the network time in sync. sp.server.timeSource.AddTimeSample(p.Addr(), msg.Timestamp)