diff --git a/components/mocks/mockNetwork.go b/components/mocks/mockNetwork.go index 7e1ab29126..c0b7724e07 100644 --- a/components/mocks/mockNetwork.go +++ b/components/mocks/mockNetwork.go @@ -91,6 +91,14 @@ func (network *MockNetwork) RegisterHandlers(dispatch []network.TaggedMessageHan func (network *MockNetwork) ClearHandlers() { } +// RegisterProcessors - empty implementation. +func (network *MockNetwork) RegisterProcessors(dispatch []network.TaggedMessageProcessor) { +} + +// ClearProcessors - empty implementation +func (network *MockNetwork) ClearProcessors() { +} + // RegisterHTTPHandler - empty implementation func (network *MockNetwork) RegisterHTTPHandler(path string, handler http.Handler) { } diff --git a/data/txHandler.go b/data/txHandler.go index 3d20e95acd..895b07269d 100644 --- a/data/txHandler.go +++ b/data/txHandler.go @@ -243,6 +243,21 @@ func (handler *TxHandler) Start() { handler.net.RegisterHandlers([]network.TaggedMessageHandler{ {Tag: protocol.TxnTag, MessageHandler: network.HandlerFunc(handler.processIncomingTxn)}, }) + + handler.net.RegisterProcessors([]network.TaggedMessageProcessor{ + { + Tag: protocol.TxnTag, + // create anonymous struct to hold the two functions and satisfy the network.MessageProcessor interface + MessageProcessor: struct { + network.ProcessorValidateFunc + network.ProcessorHandleFunc + }{ + network.ProcessorValidateFunc(handler.validateIncomingTxMessage), + network.ProcessorHandleFunc(handler.processIncomingTxMessage), + }, + }, + }) + handler.backlogWg.Add(2) go handler.backlogWorker() go handler.backlogGaugeThread() @@ -530,7 +545,7 @@ func (handler *TxHandler) deleteFromCaches(msgKey *crypto.Digest, canonicalKey * // dedupCanonical checks if the transaction group has been seen before after reencoding to canonical representation. // returns a key used for insertion if the group was not found. 
-func (handler *TxHandler) dedupCanonical(ntx int, unverifiedTxGroup []transactions.SignedTxn, consumed int) (key *crypto.Digest, isDup bool) { +func (handler *TxHandler) dedupCanonical(unverifiedTxGroup []transactions.SignedTxn, consumed int) (key *crypto.Digest, isDup bool) { // consider situations where someone want to censor transactions A // 1. Txn A is not part of a group => txn A with a valid signature is OK // Censorship attempts are: @@ -547,6 +562,7 @@ func (handler *TxHandler) dedupCanonical(ntx int, unverifiedTxGroup []transactio // - using individual txn from a group: {A, Z} could be poisoned by {A, B}, where B is invalid var d crypto.Digest + ntx := len(unverifiedTxGroup) if ntx == 1 { // a single transaction => cache/dedup canonical txn with its signature enc := unverifiedTxGroup[0].MarshalMsg(nil) @@ -574,49 +590,52 @@ func (handler *TxHandler) dedupCanonical(ntx int, unverifiedTxGroup []transactio return &d, false } -// processIncomingTxn decodes a transaction group from incoming message and enqueues into the back log for processing. -// The function also performs some input data pre-validation; -// - txn groups are cut to MaxTxGroupSize size -// - message are checked for duplicates -// - transactions are checked for duplicates - -func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) network.OutgoingMessage { +// incomingMsgDupErlCheck runs the duplicate and rate limiting checks on a raw incoming messages. 
+// Returns: the key used for insertion if the message was not found in the cache,
+// the capacity guard returned by the elastic rate limiter, and a boolean
+// indicating if the message was a duplicate or the sender is rate limited.
+func (handler *TxHandler) incomingMsgDupErlCheck(data []byte, sender network.DisconnectablePeer) (*crypto.Digest, *util.ErlCapacityGuard, bool) {
 	var msgKey *crypto.Digest
+	var capguard *util.ErlCapacityGuard
 	var isDup bool
 	if handler.msgCache != nil {
 		// check for duplicate messages
 		// this helps against relaying duplicates
-		if msgKey, isDup = handler.msgCache.CheckAndPut(rawmsg.Data); isDup {
+		if msgKey, isDup = handler.msgCache.CheckAndPut(data); isDup {
 			transactionMessagesDupRawMsg.Inc(nil)
-			return network.OutgoingMessage{Action: network.Ignore}
+			return msgKey, capguard, true
 		}
 	}
 
-	unverifiedTxGroup := make([]transactions.SignedTxn, 1)
-	dec := protocol.NewMsgpDecoderBytes(rawmsg.Data)
-	ntx := 0
-	consumed := 0
-	var err error
-	var capguard *util.ErlCapacityGuard
 	if handler.erl != nil {
 		congestedERL := float64(cap(handler.backlogQueue))*handler.backlogCongestionThreshold < float64(len(handler.backlogQueue))
 		// consume a capacity unit
 		// if the elastic rate limiter cannot vend a capacity, the error it returns
 		// is sufficient to indicate that we should enable Congestion Control, because
 		// an issue in vending capacity indicates the underlying resource (TXBacklog) is full
-		capguard, err = handler.erl.ConsumeCapacity(rawmsg.Sender.(util.ErlClient))
+		var err error
+		capguard, err = handler.erl.ConsumeCapacity(sender.(util.ErlClient))
 		if err != nil {
 			handler.erl.EnableCongestionControl()
 			// if there is no capacity, it is the same as if we failed to put the item onto the backlog, so report such
 			transactionMessagesDroppedFromBacklog.Inc(nil)
-			return network.OutgoingMessage{Action: network.Ignore}
+			return msgKey, capguard, true
 		}
 		// if the backlog Queue has 50% of its buffer back, turn congestion control off
 		if !congestedERL {
handler.erl.DisableCongestionControl() } } + return msgKey, capguard, false +} + +// decodeMsg decodes TX message buffer into transactions.SignedTxn, +// and returns number of bytes consumed from the buffer and a boolean indicating if the message was invalid. +func decodeMsg(data []byte) (unverifiedTxGroup []transactions.SignedTxn, consumed int, invalid bool) { + unverifiedTxGroup = make([]transactions.SignedTxn, 1) + dec := protocol.NewMsgpDecoderBytes(data) + ntx := 0 for { if len(unverifiedTxGroup) == ntx { @@ -630,7 +649,7 @@ func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) net break } logging.Base().Warnf("Received a non-decodable txn: %v", err) - return network.OutgoingMessage{Action: network.Disconnect} + return nil, 0, true } consumed = dec.Consumed() ntx++ @@ -639,13 +658,13 @@ func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) net if dec.Remaining() > 0 { // if something else left in the buffer - this is an error, drop transactionMessageTxGroupExcessive.Inc(nil) - return network.OutgoingMessage{Action: network.Disconnect} + return nil, 0, true } } } if ntx == 0 { logging.Base().Warnf("Received empty tx group") - return network.OutgoingMessage{Action: network.Disconnect} + return nil, 0, true } unverifiedTxGroup = unverifiedTxGroup[:ntx] @@ -654,22 +673,57 @@ func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) net transactionMessageTxGroupFull.Inc(nil) } + return unverifiedTxGroup, consumed, false +} + +// incomingTxGroupDupRateLimit checks +// - if the incoming transaction group has been seen before after reencoding to canonical representation, and +// - if the sender is rate limited by the per-application rate limiter. 
+func (handler *TxHandler) incomingTxGroupDupRateLimit(unverifiedTxGroup []transactions.SignedTxn, encodedExpectedSize int, sender network.DisconnectablePeer) (*crypto.Digest, bool) { var canonicalKey *crypto.Digest if handler.txCanonicalCache != nil { - if canonicalKey, isDup = handler.dedupCanonical(ntx, unverifiedTxGroup, consumed); isDup { + var isDup bool + if canonicalKey, isDup = handler.dedupCanonical(unverifiedTxGroup, encodedExpectedSize); isDup { transactionMessagesDupCanonical.Inc(nil) - return network.OutgoingMessage{Action: network.Ignore} + return canonicalKey, true } } // rate limit per application in a group. Limiting any app in a group drops the entire message. if handler.appLimiter != nil { congestedARL := len(handler.backlogQueue) > handler.appLimiterBacklogThreshold - if congestedARL && handler.appLimiter.shouldDrop(unverifiedTxGroup, rawmsg.Sender.(network.IPAddressable).RoutingAddr()) { + if congestedARL && handler.appLimiter.shouldDrop(unverifiedTxGroup, sender.(network.IPAddressable).RoutingAddr()) { transactionMessagesAppLimiterDrop.Inc(nil) - return network.OutgoingMessage{Action: network.Ignore} + return canonicalKey, true } } + return canonicalKey, false +} + +// processIncomingTxn decodes a transaction group from incoming message and enqueues into the back log for processing. 
+// The function also performs some input data pre-validation; +// - txn groups are cut to MaxTxGroupSize size +// - message are checked for duplicates +// - transactions are checked for duplicates +func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) network.OutgoingMessage { + msgKey, capguard, shouldDrop := handler.incomingMsgDupErlCheck(rawmsg.Data, rawmsg.Sender) + if shouldDrop { + // this TX message was found in the duplicate cache, or ERL rate-limited it + return network.OutgoingMessage{Action: network.Ignore} + } + + unverifiedTxGroup, consumed, invalid := decodeMsg(rawmsg.Data) + if invalid { + // invalid encoding or exceeding txgroup, disconnect from this peer + return network.OutgoingMessage{Action: network.Disconnect} + } + + canonicalKey, drop := handler.incomingTxGroupDupRateLimit(unverifiedTxGroup, consumed, rawmsg.Sender) + if drop { + // this re-serialized txgroup was detected as a duplicate by the canonical message cache, + // or it was rate-limited by the per-app rate limiter + return network.OutgoingMessage{Action: network.Ignore} + } select { case handler.backlogQueue <- &txBacklogMsg{ @@ -696,6 +750,75 @@ func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) net return network.OutgoingMessage{Action: network.Ignore} } +type validatedIncomingTxMessage struct { + rawmsg network.IncomingMessage + unverifiedTxGroup []transactions.SignedTxn + msgKey *crypto.Digest + canonicalKey *crypto.Digest + capguard *util.ErlCapacityGuard +} + +// validateIncomingTxMessage is the validator for the MessageProcessor implementation used by P2PNetwork. 
+func (handler *TxHandler) validateIncomingTxMessage(rawmsg network.IncomingMessage) network.ValidatedMessage { + msgKey, capguard, shouldDrop := handler.incomingMsgDupErlCheck(rawmsg.Data, rawmsg.Sender) + if shouldDrop { + // this TX message was found in the duplicate cache, or ERL rate-limited it + return network.ValidatedMessage{Action: network.Ignore, ValidatorData: nil} + } + + unverifiedTxGroup, consumed, invalid := decodeMsg(rawmsg.Data) + if invalid { + // invalid encoding or exceeding txgroup, disconnect from this peer + return network.ValidatedMessage{Action: network.Disconnect, ValidatorData: nil} + } + + canonicalKey, drop := handler.incomingTxGroupDupRateLimit(unverifiedTxGroup, consumed, rawmsg.Sender) + if drop { + // this re-serialized txgroup was detected as a duplicate by the canonical message cache, + // or it was rate-limited by the per-app rate limiter + return network.ValidatedMessage{Action: network.Ignore, ValidatorData: nil} + } + + return network.ValidatedMessage{ + Action: network.Accept, + Tag: rawmsg.Tag, + ValidatorData: &validatedIncomingTxMessage{ + rawmsg: rawmsg, + unverifiedTxGroup: unverifiedTxGroup, + msgKey: msgKey, + canonicalKey: canonicalKey, + capguard: capguard, + }, + } +} + +// processIncomingTxMessage is the handler for the MessageProcessor implementation used by P2PNetwork. +func (handler *TxHandler) processIncomingTxMessage(validatedMessage network.ValidatedMessage) network.OutgoingMessage { + msg := validatedMessage.ValidatorData.(*validatedIncomingTxMessage) + select { + case handler.backlogQueue <- &txBacklogMsg{ + rawmsg: &msg.rawmsg, + unverifiedTxGroup: msg.unverifiedTxGroup, + rawmsgDataHash: msg.msgKey, + unverifiedTxGroupHash: msg.canonicalKey, + capguard: msg.capguard, + }: + default: + // if we failed here we want to increase the corresponding metric. It might suggest that we + // want to increase the queue size. 
+ transactionMessagesDroppedFromBacklog.Inc(nil) + + // additionally, remove the txn from duplicate caches to ensure it can be re-submitted + if handler.txCanonicalCache != nil && msg.canonicalKey != nil { + handler.txCanonicalCache.Delete(msg.canonicalKey) + } + if handler.msgCache != nil && msg.msgKey != nil { + handler.msgCache.DeleteByKey(msg.msgKey) + } + } + return network.OutgoingMessage{Action: network.Ignore} +} + var errBackLogFullLocal = errors.New("backlog full") // LocalTransaction is a special shortcut handler for local transactions and intended to be used diff --git a/network/gossipNode.go b/network/gossipNode.go index fb0a415876..eeeca95167 100644 --- a/network/gossipNode.go +++ b/network/gossipNode.go @@ -88,6 +88,12 @@ type GossipNode interface { // ClearHandlers deregisters all the existing message handlers. ClearHandlers() + // RegisterProcessors adds to the set of given message processors. + RegisterProcessors(dispatch []TaggedMessageProcessor) + + // ClearProcessors deregisters all the existing message processors. + ClearProcessors() + // GetHTTPClient returns a http.Client with a suitable for the network Transport // that would also limit the number of outgoing connections. GetHTTPClient(address string) (*http.Client, error) @@ -162,6 +168,14 @@ type OutgoingMessage struct { OnRelease func() } +// ValidatedMessage is a message that has been validated and is ready to be processed. 
+// Think as an intermediate one between IncomingMessage and OutgoingMessage +type ValidatedMessage struct { + Action ForwardingPolicy + Tag Tag + ValidatorData interface{} +} + // ForwardingPolicy is an enum indicating to whom we should send a message // //msgp:ignore ForwardingPolicy @@ -179,6 +193,9 @@ const ( // Respond - reply to the sender Respond + + // Accept - accept for further processing after successful validation + Accept ) // MessageHandler takes a IncomingMessage (e.g., vote, transaction), processes it, and returns what (if anything) @@ -189,20 +206,52 @@ type MessageHandler interface { Handle(message IncomingMessage) OutgoingMessage } -// HandlerFunc represents an implemenation of the MessageHandler interface +// HandlerFunc represents an implementation of the MessageHandler interface type HandlerFunc func(message IncomingMessage) OutgoingMessage -// Handle implements MessageHandler.Handle, calling the handler with the IncomingKessage and returning the OutgoingMessage +// Handle implements MessageHandler.Handle, calling the handler with the IncomingMessage and returning the OutgoingMessage func (f HandlerFunc) Handle(message IncomingMessage) OutgoingMessage { return f(message) } +// MessageProcessor takes a IncomingMessage (e.g., vote, transaction), processes it, and returns what (if anything) +// to send to the network in response. +// This is an extension of the MessageHandler that works in two stages: validate ->[result]-> handle. 
+type MessageProcessor interface { + Validate(message IncomingMessage) ValidatedMessage + Handle(message ValidatedMessage) OutgoingMessage +} + +// ProcessorValidateFunc represents an implementation of the MessageProcessor interface +type ProcessorValidateFunc func(message IncomingMessage) ValidatedMessage + +// ProcessorHandleFunc represents an implementation of the MessageProcessor interface +type ProcessorHandleFunc func(message ValidatedMessage) OutgoingMessage + +// Validate implements MessageProcessor.Validate, calling the validator with the IncomingMessage and returning the action +// and validation extra data that can be use as the handler input. +func (f ProcessorValidateFunc) Validate(message IncomingMessage) ValidatedMessage { + return f(message) +} + +// Handle implements MessageProcessor.Handle calling the handler with the ValidatedMessage and returning the OutgoingMessage +func (f ProcessorHandleFunc) Handle(message ValidatedMessage) OutgoingMessage { + return f(message) +} + // TaggedMessageHandler receives one type of broadcast messages type TaggedMessageHandler struct { Tag MessageHandler } +// TaggedMessageProcessor receives one type of broadcast messages +// and performs two stage processing: validating and handling +type TaggedMessageProcessor struct { + Tag + MessageProcessor +} + // Propagate is a convenience function to save typing in the common case of a message handler telling us to propagate an incoming message // "return network.Propagate(msg)" instead of "return network.OutgoingMsg{network.Broadcast, msg.Tag, msg.Data}" func Propagate(msg IncomingMessage) OutgoingMessage { diff --git a/network/hybridNetwork.go b/network/hybridNetwork.go index 7abb2ab569..6041d95f9a 100644 --- a/network/hybridNetwork.go +++ b/network/hybridNetwork.go @@ -180,6 +180,18 @@ func (n *HybridP2PNetwork) ClearHandlers() { n.wsNetwork.ClearHandlers() } +// RegisterProcessors adds to the set of given message handlers. 
+func (n *HybridP2PNetwork) RegisterProcessors(dispatch []TaggedMessageProcessor) { + n.p2pNetwork.RegisterProcessors(dispatch) + n.wsNetwork.RegisterProcessors(dispatch) +} + +// ClearProcessors deregisters all the existing message handlers. +func (n *HybridP2PNetwork) ClearProcessors() { + n.p2pNetwork.ClearProcessors() + n.wsNetwork.ClearProcessors() +} + // GetHTTPClient returns a http.Client with a suitable for the network Transport // that would also limit the number of outgoing connections. func (n *HybridP2PNetwork) GetHTTPClient(address string) (*http.Client, error) { diff --git a/network/multiplexer.go b/network/multiplexer.go index 0e97d63f28..2d69259c9d 100644 --- a/network/multiplexer.go +++ b/network/multiplexer.go @@ -24,32 +24,55 @@ import ( // Multiplexer is a message handler that sorts incoming messages by Tag and passes // them along to the relevant message handler for that type of message. type Multiplexer struct { - msgHandlers atomic.Value // stores map[Tag]MessageHandler, an immutable map. + msgHandlers atomic.Value // stores map[Tag]MessageHandler, an immutable map. + msgProcessors atomic.Value // stores map[Tag]MessageProcessor, an immutable map. } // MakeMultiplexer creates an empty Multiplexer func MakeMultiplexer() *Multiplexer { m := &Multiplexer{} - m.ClearHandlers([]Tag{}) // allocate the map + m.ClearHandlers(nil) // allocate the map + m.ClearProcessors(nil) // allocate the map return m } -// getHandlersMap retrieves the handlers map. -func (m *Multiplexer) getHandlersMap() map[Tag]MessageHandler { - handlersVal := m.msgHandlers.Load() - if handlers, valid := handlersVal.(map[Tag]MessageHandler); valid { +// getMap retrieves a typed map from an atomic.Value. +func getMap[T any](source *atomic.Value) map[Tag]T { + mp := source.Load() + if handlers, valid := mp.(map[Tag]T); valid { return handlers } return nil } -// Retrives the handler for the given message Tag from the handlers array while taking a read lock. 
-func (m *Multiplexer) getHandler(tag Tag) (MessageHandler, bool) {
-	if handlers := m.getHandlersMap(); handlers != nil {
+// getHandlersMap retrieves the handlers map.
+func (m *Multiplexer) getHandlersMap() map[Tag]MessageHandler {
+	return getMap[MessageHandler](&m.msgHandlers)
+}
+
+// getProcessorsMap retrieves the processors map.
+func (m *Multiplexer) getProcessorsMap() map[Tag]MessageProcessor {
+	return getMap[MessageProcessor](&m.msgProcessors)
+}
+
+// Retrieves the handler for the given message Tag from the given atomic value.
+func getHandler[T any](source *atomic.Value, tag Tag) (T, bool) {
+	if handlers := getMap[T](source); handlers != nil {
 		handler, ok := handlers[tag]
 		return handler, ok
 	}
-	return nil, false
+	var empty T
+	return empty, false
+}
+
+// Retrieves the handler for the given message Tag from the handlers array.
+func (m *Multiplexer) getHandler(tag Tag) (MessageHandler, bool) {
+	return getHandler[MessageHandler](&m.msgHandlers, tag)
+}
+
+// Retrieves the processor for the given message Tag from the processors array.
+func (m *Multiplexer) getProcessor(tag Tag) (MessageProcessor, bool) {
+	return getHandler[MessageProcessor](&m.msgProcessors, tag)
 }
 
 // Handle is the "input" side of the multiplexer. It dispatches the message to the previously defined handler.
@@ -63,6 +86,28 @@ func (m *Multiplexer) Handle(msg IncomingMessage) OutgoingMessage {
 	return OutgoingMessage{}
 }
 
+// Validate is an alternative "input" side of the multiplexer. It dispatches the message to the previously defined validator.
+func (m *Multiplexer) Validate(msg IncomingMessage) ValidatedMessage {
+	handler, ok := m.getProcessor(msg.Tag)
+
+	if ok {
+		outmsg := handler.Validate(msg)
+		return outmsg
+	}
+	return ValidatedMessage{}
+}
+
+func (m *Multiplexer) Process(msg ValidatedMessage) OutgoingMessage { + handler, ok := m.getProcessor(msg.Tag) + + if ok { + outmsg := handler.Handle(msg) + return outmsg + } + return OutgoingMessage{} +} + // RegisterHandlers registers the set of given message handlers. func (m *Multiplexer) RegisterHandlers(dispatch []TaggedMessageHandler) { mp := make(map[Tag]MessageHandler) @@ -80,10 +125,27 @@ func (m *Multiplexer) RegisterHandlers(dispatch []TaggedMessageHandler) { m.msgHandlers.Store(mp) } -// ClearHandlers deregisters all the existing message handlers other than the one provided in the excludeTags list -func (m *Multiplexer) ClearHandlers(excludeTags []Tag) { +// RegisterProcessors registers the set of given message handlers. +func (m *Multiplexer) RegisterProcessors(dispatch []TaggedMessageProcessor) { + mp := make(map[Tag]MessageProcessor) + if existingMap := m.getProcessorsMap(); existingMap != nil { + for k, v := range existingMap { + mp[k] = v + } + } + for _, v := range dispatch { + if _, has := mp[v.Tag]; has { + panic(fmt.Sprintf("Already registered a handler for tag %v", v.Tag)) + } + mp[v.Tag] = v.MessageProcessor + } + m.msgProcessors.Store(mp) +} + +// ClearProcessors deregisters all the existing message handlers other than the one provided in the excludeTags list +func clear[T any](target *atomic.Value, excludeTags []Tag) { if len(excludeTags) == 0 { - m.msgHandlers.Store(make(map[Tag]MessageHandler)) + target.Store(make(map[Tag]T)) return } @@ -93,13 +155,23 @@ func (m *Multiplexer) ClearHandlers(excludeTags []Tag) { excludeTagsMap[tag] = true } - currentHandlersMap := m.getHandlersMap() - newMap := make(map[Tag]MessageHandler, len(excludeTagsMap)) - for tag, handler := range currentHandlersMap { + currentMap := getMap[T](target) + newMap := make(map[Tag]T, len(excludeTagsMap)) + for tag, handler := range currentMap { if excludeTagsMap[tag] { newMap[tag] = handler } } - m.msgHandlers.Store(newMap) + target.Store(newMap) +} + +// ClearHandlers 
deregisters all the existing message handlers other than the one provided in the excludeTags list +func (m *Multiplexer) ClearHandlers(excludeTags []Tag) { + clear[MessageHandler](&m.msgHandlers, excludeTags) +} + +// ClearProcessors deregisters all the existing message handlers other than the one provided in the excludeTags list +func (m *Multiplexer) ClearProcessors(excludeTags []Tag) { + clear[MessageProcessor](&m.msgProcessors, excludeTags) } diff --git a/network/p2pNetwork.go b/network/p2pNetwork.go index dd830ac5d7..5eaf6ec36f 100644 --- a/network/p2pNetwork.go +++ b/network/p2pNetwork.go @@ -659,6 +659,16 @@ func (n *P2PNetwork) ClearHandlers() { n.handler.ClearHandlers([]Tag{}) } +// RegisterProcessors adds to the set of given message handlers. +func (n *P2PNetwork) RegisterProcessors(dispatch []TaggedMessageProcessor) { + n.handler.RegisterProcessors(dispatch) +} + +// ClearProcessors deregisters all the existing message handlers. +func (n *P2PNetwork) ClearProcessors() { + n.handler.ClearProcessors([]Tag{}) +} + // GetHTTPClient returns a http.Client with a suitable for the network Transport // that would also limit the number of outgoing connections. func (n *P2PNetwork) GetHTTPClient(address string) (*http.Client, error) { @@ -884,11 +894,12 @@ func (n *P2PNetwork) txTopicHandleLoop() { sub.Cancel() return } + // if there is a self-sent the message no need to process it. + if msg.ReceivedFrom == n.service.ID() { + continue + } - // discard TX message. 
- // from gossipsub's point of view, it's just waiting to hear back from the validator, - // and txHandler does all its work in the validator, so we don't need to do anything here - _ = msg + _ = n.handler.Process(msg.ValidatorData.(ValidatedMessage)) // participation or configuration change, cancel subscription and quit if !n.wantTXGossip.Load() { @@ -923,14 +934,15 @@ func (n *P2PNetwork) txTopicValidator(ctx context.Context, peerID peer.ID, msg * peerStats.txReceived.Add(1) n.peerStatsMu.Unlock() - outmsg := n.handler.Handle(inmsg) + outmsg := n.handler.Validate(inmsg) // there was a decision made in the handler about this message switch outmsg.Action { case Ignore: return pubsub.ValidationIgnore case Disconnect: return pubsub.ValidationReject - case Broadcast: // TxHandler.processIncomingTxn does not currently return this Action + case Accept: + msg.ValidatorData = outmsg return pubsub.ValidationAccept default: n.log.Warnf("handler returned invalid action %d", outmsg.Action) diff --git a/network/p2pNetwork_test.go b/network/p2pNetwork_test.go index 4e929723f6..7cc590ff02 100644 --- a/network/p2pNetwork_test.go +++ b/network/p2pNetwork_test.go @@ -18,6 +18,7 @@ package network import ( "context" + "errors" "fmt" "io" "net/http" @@ -98,15 +99,26 @@ func TestP2PSubmitTX(t *testing.T) { // now we should be connected in a line: B <-> A <-> C where both B and C are connected to A but not each other // Since we aren't using the transaction handler in this test, we need to register a pass-through handler - passThroughHandler := []TaggedMessageHandler{ - {Tag: protocol.TxnTag, MessageHandler: HandlerFunc(func(msg IncomingMessage) OutgoingMessage { - return OutgoingMessage{Action: Broadcast} - })}, + passThroughHandler := []TaggedMessageProcessor{ + { + Tag: protocol.TxnTag, + MessageProcessor: struct { + ProcessorValidateFunc + ProcessorHandleFunc + }{ + ProcessorValidateFunc(func(msg IncomingMessage) ValidatedMessage { + return ValidatedMessage{Action: Accept, Tag: 
msg.Tag, ValidatorData: nil} + }), + ProcessorHandleFunc(func(msg ValidatedMessage) OutgoingMessage { + return OutgoingMessage{Action: Ignore} + }), + }, + }, } - netA.RegisterHandlers(passThroughHandler) - netB.RegisterHandlers(passThroughHandler) - netC.RegisterHandlers(passThroughHandler) + netA.RegisterProcessors(passThroughHandler) + netB.RegisterProcessors(passThroughHandler) + netC.RegisterProcessors(passThroughHandler) // send messages from B and confirm that they get received by C (via A) for i := 0; i < 10; i++ { @@ -178,14 +190,26 @@ func TestP2PSubmitTXNoGossip(t *testing.T) { time.Sleep(time.Second) // give time for peers to connect. // ensure netC cannot receive messages - passThroughHandler := []TaggedMessageHandler{ - {Tag: protocol.TxnTag, MessageHandler: HandlerFunc(func(msg IncomingMessage) OutgoingMessage { - return OutgoingMessage{Action: Broadcast} - })}, + + passThroughHandler := []TaggedMessageProcessor{ + { + Tag: protocol.TxnTag, + MessageProcessor: struct { + ProcessorValidateFunc + ProcessorHandleFunc + }{ + ProcessorValidateFunc(func(msg IncomingMessage) ValidatedMessage { + return ValidatedMessage{Action: Accept, Tag: msg.Tag, ValidatorData: nil} + }), + ProcessorHandleFunc(func(msg ValidatedMessage) OutgoingMessage { + return OutgoingMessage{Action: Ignore} + }), + }, + }, } - netB.RegisterHandlers(passThroughHandler) - netC.RegisterHandlers(passThroughHandler) + netB.RegisterProcessors(passThroughHandler) + netC.RegisterProcessors(passThroughHandler) for i := 0; i < 10; i++ { err = netA.Broadcast(context.Background(), protocol.TxnTag, []byte(fmt.Sprintf("test %d", i)), false, nil) require.NoError(t, err) @@ -207,7 +231,7 @@ func TestP2PSubmitTXNoGossip(t *testing.T) { 50*time.Millisecond, ) - // check netB did not receive the messages + // check netC did not receive the messages netC.peerStatsMu.Lock() _, ok := netC.peerStats[netA.service.ID()] netC.peerStatsMu.Unlock() @@ -804,9 +828,33 @@ func TestP2PRelay(t *testing.T) { return 
netA.hasPeers() && netB.hasPeers() }, 2*time.Second, 50*time.Millisecond) - counter := newMessageCounter(t, 1) - counterDone := counter.done - netA.RegisterHandlers([]TaggedMessageHandler{{Tag: protocol.TxnTag, MessageHandler: counter}}) + makeCounterHandler := func(numExpected int) ([]TaggedMessageProcessor, *int, chan struct{}) { + numActual := 0 + counterDone := make(chan struct{}) + counterHandler := []TaggedMessageProcessor{ + { + Tag: protocol.TxnTag, + MessageProcessor: struct { + ProcessorValidateFunc + ProcessorHandleFunc + }{ + ProcessorValidateFunc(func(msg IncomingMessage) ValidatedMessage { + return ValidatedMessage{Action: Accept, Tag: msg.Tag, ValidatorData: nil} + }), + ProcessorHandleFunc(func(msg ValidatedMessage) OutgoingMessage { + numActual++ + if numActual >= numExpected { + close(counterDone) + } + return OutgoingMessage{Action: Ignore} + }), + }, + }, + } + return counterHandler, &numActual, counterDone + } + counterHandler, _, counterDone := makeCounterHandler(1) + netA.RegisterProcessors(counterHandler) // send 5 messages from both netB to netA // since there is no node with listening address set => no messages should be received @@ -848,10 +896,9 @@ func TestP2PRelay(t *testing.T) { }, 2*time.Second, 50*time.Millisecond) const expectedMsgs = 10 - counter = newMessageCounter(t, expectedMsgs) - counterDone = counter.done - netA.ClearHandlers() - netA.RegisterHandlers([]TaggedMessageHandler{{Tag: protocol.TxnTag, MessageHandler: counter}}) + counterHandler, count, counterDone := makeCounterHandler(expectedMsgs) + netA.ClearProcessors() + netA.RegisterProcessors(counterHandler) for i := 0; i < expectedMsgs/2; i++ { err := netB.Relay(context.Background(), protocol.TxnTag, []byte{1, 2, 3, byte(i)}, true, nil) @@ -868,28 +915,41 @@ func TestP2PRelay(t *testing.T) { select { case <-counterDone: case <-time.After(2 * time.Second): - if counter.count < expectedMsgs { - require.Failf(t, "One or more messages failed to reach destination network", "%d 
> %d", expectedMsgs, counter.count) - } else if counter.count > expectedMsgs { - require.Failf(t, "One or more messages that were expected to be dropped, reached destination network", "%d < %d", expectedMsgs, counter.count) + if *count < expectedMsgs { + require.Failf(t, "One or more messages failed to reach destination network", "%d > %d", expectedMsgs, *count) + } else if *count > expectedMsgs { + require.Failf(t, "One or more messages that were expected to be dropped, reached destination network", "%d < %d", expectedMsgs, *count) } } } type mockSubPService struct { mockService - count atomic.Int64 + count atomic.Int64 + otherPeerID peer.ID + shouldNextFail bool } type mockSubscription struct { + peerID peer.ID + shouldNextFail bool } -func (m *mockSubscription) Next(ctx context.Context) (*pubsub.Message, error) { return nil, nil } -func (m *mockSubscription) Cancel() {} +func (m *mockSubscription) Next(ctx context.Context) (*pubsub.Message, error) { + if m.shouldNextFail { + return nil, errors.New("mockSubscription error") + } + return &pubsub.Message{ReceivedFrom: m.peerID}, nil +} +func (m *mockSubscription) Cancel() {} func (m *mockSubPService) Subscribe(topic string, val pubsub.ValidatorEx) (p2p.SubNextCancellable, error) { m.count.Add(1) - return &mockSubscription{}, nil + otherPeerID := m.otherPeerID + if otherPeerID == "" { + otherPeerID = "mockSubPServicePeerID" + } + return &mockSubscription{peerID: otherPeerID, shouldNextFail: m.shouldNextFail}, nil } // TestP2PWantTXGossip checks txTopicHandleLoop runs as expected on wantTXGossip changes @@ -900,7 +960,8 @@ func TestP2PWantTXGossip(t *testing.T) { // cancelled context to trigger subscription.Next to return ctx, cancel := context.WithCancel(context.Background()) cancel() - mockService := &mockSubPService{} + peerID := peer.ID("myPeerID") + mockService := &mockSubPService{mockService: mockService{id: peerID}, shouldNextFail: true} net := &P2PNetwork{ service: mockService, log: logging.TestingLog(t), 
diff --git a/network/wsNetwork.go b/network/wsNetwork.go index 50307f2738..13e0fe71f5 100644 --- a/network/wsNetwork.go +++ b/network/wsNetwork.go @@ -853,6 +853,14 @@ func (wn *WebsocketNetwork) ClearHandlers() { wn.handler.ClearHandlers([]Tag{protocol.PingTag, protocol.PingReplyTag, protocol.NetPrioResponseTag}) } +// RegisterProcessors registers the set of given message handlers. +func (wn *WebsocketNetwork) RegisterProcessors(dispatch []TaggedMessageProcessor) { +} + +// ClearProcessors deregisters all the existing message handlers. +func (wn *WebsocketNetwork) ClearProcessors() { +} + func (wn *WebsocketNetwork) setHeaders(header http.Header) { localTelemetryGUID := wn.log.GetTelemetryGUID() localInstanceName := wn.log.GetInstanceName()