diff --git a/peer/peer.go b/peer/peer.go index 17cb477645..3375092709 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -1737,27 +1737,42 @@ func (p *Peer) Disconnect() { close(p.quit) } -// handleRemoteVersionMsg is invoked when a version wire message is received -// from the remote peer. It will return an error if the remote peer's version -// is not compatible with ours. -func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { +// readRemoteVersionMsg waits for the next message to arrive from the remote +// peer. If the next message is not a version message or the version is not +// acceptable then return an error. +func (p *Peer) readRemoteVersionMsg() error { + // Read their version message. + remoteMsg, _, err := p.readMessage() + if err != nil { + return err + } + + // Notify and disconnect clients if the first message is not a version + // message. + msg, ok := remoteMsg.(*wire.MsgVersion) + if !ok { + reason := "a version message must precede all others" + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed, + reason) + _ = p.writeMessage(rejectMsg) + return errors.New(reason) + } + // Detect self connections. if !allowSelfConns && sentNonces.Exists(msg.Nonce) { return errors.New("disconnecting peer connected to self") } - // Notify and disconnect clients that have a protocol version that is - // too old. - if msg.ProtocolVersion < int32(wire.InitialProcotolVersion) { - // Send a reject message indicating the protocol version is - // obsolete and wait for the message to be sent before - // disconnecting. - reason := fmt.Sprintf("protocol version must be %d or greater", - wire.InitialProcotolVersion) - rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete, - reason) - return p.writeMessage(rejectMsg) - } + // Negotiate the protocol version and set the services to what the remote + // peer advertised. + p.flagsMtx.Lock() + p.advertisedProtoVer = uint32(msg.ProtocolVersion) + p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) + p.versionKnown = true + p.services = msg.Services + p.flagsMtx.Unlock() + log.Debugf("Negotiated protocol version %d for peer %s", + p.protocolVersion, p) // Updating a bunch of stats. p.statsMtx.Lock() @@ -1768,51 +1783,31 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix() p.statsMtx.Unlock() - // Negotiate the protocol version. + // Set the peer's ID and user agent. p.flagsMtx.Lock() - p.advertisedProtoVer = uint32(msg.ProtocolVersion) - p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) - p.versionKnown = true - log.Debugf("Negotiated protocol version %d for peer %s", - p.protocolVersion, p) - // Set the peer's ID. p.id = atomic.AddInt32(&nodeCount, 1) - // Set the supported services for the peer to what the remote peer - // advertised. - p.services = msg.Services - // Set the remote peer's user agent. p.userAgent = msg.UserAgent p.flagsMtx.Unlock() - return nil -} - -// readRemoteVersionMsg waits for the next message to arrive from the remote -// peer. If the next message is not a version message or the version is not -// acceptable then return an error. -func (p *Peer) readRemoteVersionMsg() error { - // Read their version message. - msg, _, err := p.readMessage() - if err != nil { - return err - } - - remoteVerMsg, ok := msg.(*wire.MsgVersion) - if !ok { - errStr := "A version message must precede all others" - log.Errorf(errStr) - rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectMalformed, - errStr) - return p.writeMessage(rejectMsg) + // Invoke the callback if specified. + if p.cfg.Listeners.OnVersion != nil { + p.cfg.Listeners.OnVersion(p, msg) } - if err := p.handleRemoteVersionMsg(remoteVerMsg); err != nil { - return err + // Notify and disconnect clients that have a protocol version that is + // too old. + if msg.ProtocolVersion < int32(wire.InitialProcotolVersion) { + // Send a reject message indicating the protocol version is + // obsolete and wait for the message to be sent before + // disconnecting. + reason := fmt.Sprintf("protocol version must be %d or greater", + wire.InitialProcotolVersion) + rejectMsg := wire.NewMsgReject(msg.Command(), wire.RejectObsolete, + reason) + _ = p.writeMessage(rejectMsg) + return errors.New(reason) } - if p.cfg.Listeners.OnVersion != nil { - p.cfg.Listeners.OnVersion(p, remoteVerMsg) - } return nil } @@ -1953,9 +1948,11 @@ func (p *Peer) start() error { select { case err := <-negotiateErr: if err != nil { + p.Disconnect() return err } case <-time.After(negotiateTimeout): + p.Disconnect() return errors.New("protocol negotiation timeout") } log.Debugf("Connected to %s", p.Addr()) diff --git a/server.go b/server.go index e24c0d4330..c99c430fbf 100644 --- a/server.go +++ b/server.go @@ -327,6 +327,12 @@ func (sp *serverPeer) addBanScore(persistent, transient uint32, reason string) { // to negotiate the protocol version details as well as kick start the // communications. func (sp *serverPeer) OnVersion(p *peer.Peer, msg *wire.MsgVersion) { + // Ignore peers that have a protcol version that is too old. The peer + // negotiation logic will disconnect it after this callback returns. + if msg.ProtocolVersion < int32(wire.InitialProcotolVersion) { + return + } + // Add the remote peer time as a sample for creating an offset against // the local clock to keep the network time in sync. sp.server.timeSource.AddTimeSample(p.Addr(), msg.Timestamp)