diff --git a/network/p2p/client.go b/network/p2p/client.go index 1c3c9bee01d..9ae4bf92008 100644 --- a/network/p2p/client.go +++ b/network/p2p/client.go @@ -9,7 +9,6 @@ import ( "fmt" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/utils/set" ) @@ -39,11 +38,17 @@ type CrossChainAppResponseCallback func( err error, ) +type clientSender interface { + AppRequestSender + AppGossipSender + CrossChainAppRequestSender +} + type Client struct { handlerID uint64 handlerPrefix []byte router *router - sender common.AppSender + sender clientSender options *clientOptions } @@ -88,7 +93,7 @@ func (c *Client) AppRequest( if err := c.sender.SendAppRequest( ctx, - set.Of(nodeID), + nodeID, requestID, appRequestBytes, ); err != nil { diff --git a/network/p2p/gossip/gossip.go b/network/p2p/gossip/gossip.go index 94d49260da4..05b46a0bd99 100644 --- a/network/p2p/gossip/gossip.go +++ b/network/p2p/gossip/gossip.go @@ -23,6 +23,7 @@ import ( var ( _ Gossiper = (*ValidatorGossiper)(nil) _ Gossiper = (*PullGossiper[testTx, *testTx])(nil) + _ Gossiper = (*NoOpGossiper)(nil) ) // Gossiper gossips Gossipables to other nodes @@ -196,3 +197,9 @@ func Every(ctx context.Context, log logging.Logger, gossiper Gossiper, frequency } } } + +type NoOpGossiper struct{} + +func (NoOpGossiper) Gossip(context.Context) error { + return nil +} diff --git a/network/p2p/gossip/gossip_test.go b/network/p2p/gossip/gossip_test.go index d30fac0008e..2823afbcf07 100644 --- a/network/p2p/gossip/gossip_test.go +++ b/network/p2p/gossip/gossip_test.go @@ -15,7 +15,6 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/network/p2p" - "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" ) @@ -115,8 +114,11 @@ func TestGossiperGossip(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require := require.New(t) + ctx := context.Background() - responseSender := &common.SenderTest{} + responseSender := &p2p.FakeSender{ + SentAppResponse: make(chan []byte, 1), + } responseNetwork := p2p.NewNetwork(logging.NoLog{}, responseSender, prometheus.NewRegistry(), "") responseBloom, err := NewBloomFilter(1000, 0.01) require.NoError(err) @@ -133,25 +135,13 @@ func TestGossiperGossip(t *testing.T) { _, err = responseNetwork.NewAppProtocol(0x0, handler) require.NoError(err) - requestSender := &common.SenderTest{ - SendAppRequestF: func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { - go func() { - require.NoError(responseNetwork.AppRequest(ctx, ids.EmptyNodeID, requestID, time.Time{}, request)) - }() - return nil - }, + requestSender := &p2p.FakeSender{ + SentAppRequest: make(chan []byte, 1), } requestNetwork := p2p.NewNetwork(logging.NoLog{}, requestSender, prometheus.NewRegistry(), "") require.NoError(requestNetwork.Connected(context.Background(), ids.EmptyNodeID, nil)) - gossiped := make(chan struct{}) - responseSender.SendAppResponseF = func(ctx context.Context, nodeID ids.NodeID, requestID uint32, appResponseBytes []byte) error { - require.NoError(requestNetwork.AppResponse(ctx, nodeID, requestID, appResponseBytes)) - close(gossiped) - return nil - } - bloom, err := NewBloomFilter(1000, 0.01) require.NoError(err) requestSet := testSet{ @@ -181,8 +171,9 @@ func TestGossiperGossip(t *testing.T) { received.Add(tx) } - require.NoError(gossiper.Gossip(context.Background())) - <-gossiped + require.NoError(gossiper.Gossip(ctx)) + require.NoError(responseNetwork.AppRequest(ctx, ids.EmptyNodeID, 1, time.Time{}, <-requestSender.SentAppRequest)) + require.NoError(requestNetwork.AppResponse(ctx, ids.EmptyNodeID, 1, <-responseSender.SentAppResponse)) require.Len(requestSet.set, tt.expectedLen) require.Subset(tt.expectedPossibleValues, requestSet.set.List()) diff --git a/network/p2p/handler.go b/network/p2p/handler.go index b85195a5255..2a12c82bd27 100644 --- a/network/p2p/handler.go +++ b/network/p2p/handler.go @@ -12,7 +12,6 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" - "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/utils/logging" ) @@ -92,7 +91,7 @@ type responder struct { Handler handlerID uint64 log logging.Logger - sender common.AppSender + sender routerSender } // AppRequest calls the underlying handler and sends back the response to nodeID diff --git a/network/p2p/mocks/mock_handler.go b/network/p2p/mocks/mock_handler.go deleted file mode 100644 index 0d4147d2318..00000000000 --- a/network/p2p/mocks/mock_handler.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ava-labs/avalanchego/network/p2p (interfaces: Handler) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - context "context" - reflect "reflect" - time "time" - - ids "github.com/ava-labs/avalanchego/ids" - gomock "go.uber.org/mock/gomock" -) - -// MockHandler is a mock of Handler interface. -type MockHandler struct { - ctrl *gomock.Controller - recorder *MockHandlerMockRecorder -} - -// MockHandlerMockRecorder is the mock recorder for MockHandler. -type MockHandlerMockRecorder struct { - mock *MockHandler -} - -// NewMockHandler creates a new mock instance. -func NewMockHandler(ctrl *gomock.Controller) *MockHandler { - mock := &MockHandler{ctrl: ctrl} - mock.recorder = &MockHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockHandler) EXPECT() *MockHandlerMockRecorder { - return m.recorder -} - -// AppGossip mocks base method. -func (m *MockHandler) AppGossip(arg0 context.Context, arg1 ids.NodeID, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AppGossip", arg0, arg1, arg2) -} - -// AppGossip indicates an expected call of AppGossip. -func (mr *MockHandlerMockRecorder) AppGossip(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppGossip", reflect.TypeOf((*MockHandler)(nil).AppGossip), arg0, arg1, arg2) -} - -// AppRequest mocks base method. -func (m *MockHandler) AppRequest(arg0 context.Context, arg1 ids.NodeID, arg2 time.Time, arg3 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AppRequest", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// AppRequest indicates an expected call of AppRequest. -func (mr *MockHandlerMockRecorder) AppRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppRequest", reflect.TypeOf((*MockHandler)(nil).AppRequest), arg0, arg1, arg2, arg3) -} - -// CrossChainAppRequest mocks base method. -func (m *MockHandler) CrossChainAppRequest(arg0 context.Context, arg1 ids.ID, arg2 time.Time, arg3 []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CrossChainAppRequest", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CrossChainAppRequest indicates an expected call of CrossChainAppRequest. -func (mr *MockHandlerMockRecorder) CrossChainAppRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CrossChainAppRequest", reflect.TypeOf((*MockHandler)(nil).CrossChainAppRequest), arg0, arg1, arg2, arg3) -} diff --git a/network/p2p/network.go b/network/p2p/network.go index 444c2e4b940..b3a22d52064 100644 --- a/network/p2p/network.go +++ b/network/p2p/network.go @@ -52,7 +52,7 @@ type clientOptions struct { // NewNetwork returns an instance of Network func NewNetwork( log logging.Logger, - sender common.AppSender, + sender AppSender, metrics prometheus.Registerer, namespace string, ) *Network { @@ -72,7 +72,7 @@ type Network struct { Peers *Peers log logging.Logger - sender common.AppSender + sender AppSender metrics prometheus.Registerer namespace string diff --git a/network/p2p/network_test.go b/network/p2p/network_test.go index 590858a0c46..bbcf62d3b32 100644 --- a/network/p2p/network_test.go +++ b/network/p2p/network_test.go @@ -5,19 +5,13 @@ package p2p import ( "context" - "sync" "testing" "time" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/network/p2p/mocks" - "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/math" @@ -25,289 +19,228 @@ import ( "github.com/ava-labs/avalanchego/version" ) +const handlerID = 1337 + +// Tests that the Client callback is called on a successful response func TestAppRequestResponse(t *testing.T) { - handlerID := uint64(0x0) - request := []byte("request") - response := []byte("response") - nodeID := ids.GenerateTestNodeID() - chainID := ids.GenerateTestID() + require := require.New(t) + ctx := context.Background() - ctxKey := new(string) - ctxVal := new(string) - *ctxKey = "foo" - *ctxVal = "bar" + sender := FakeSender{ + SentAppRequest: make(chan []byte, 1), + } + network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") - tests := []struct { - name string - requestFunc func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) - }{ - { - name: "app request", - requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.SendAppRequestF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { - for range nodeIDs { - go func() { - require.NoError(t, network.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) - }() - } + client, err := network.NewAppProtocol(handlerID, &NoOpHandler{}) + require.NoError(err) - return nil - } - sender.SendAppResponseF = func(ctx context.Context, _ ids.NodeID, requestID uint32, response []byte) error { - go func() { - ctx = context.WithValue(ctx, ctxKey, ctxVal) - require.NoError(t, network.AppResponse(ctx, nodeID, requestID, response)) - }() + wantResponse := []byte("response") + wantNodeID := ids.GenerateTestNodeID() + done := make(chan struct{}) - return nil - } - handler.EXPECT(). - AppRequest(context.Background(), nodeID, gomock.Any(), request). - DoAndReturn(func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { - return response, nil - }) - - callback := func(ctx context.Context, actualNodeID ids.NodeID, actualResponse []byte, err error) { - defer wg.Done() - - require.NoError(t, err) - require.Equal(t, ctxVal, ctx.Value(ctxKey)) - require.Equal(t, nodeID, actualNodeID) - require.Equal(t, response, actualResponse) - } + callback := func(_ context.Context, gotNodeID ids.NodeID, gotResponse []byte, err error) { + require.Equal(wantNodeID, gotNodeID) + require.NoError(err) + require.Equal(wantResponse, gotResponse) - require.NoError(t, client.AppRequestAny(context.Background(), request, callback)) - }, - }, - { - name: "app request failed", - requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.SendAppRequestF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { - for range nodeIDs { - go func() { - require.NoError(t, network.AppRequestFailed(ctx, nodeID, requestID)) - }() - } + close(done) + } - return nil - } + require.NoError(client.AppRequest(ctx, set.Of(wantNodeID), []byte("request"), callback)) + <-sender.SentAppRequest - callback := func(_ context.Context, actualNodeID ids.NodeID, actualResponse []byte, err error) { - defer wg.Done() + require.NoError(network.AppResponse(ctx, wantNodeID, 1, wantResponse)) + <-done +} - require.ErrorIs(t, err, ErrAppRequestFailed) - require.Equal(t, nodeID, actualNodeID) - require.Nil(t, actualResponse) - } +// Tests that the Client callback is given an error if the request fails +func TestAppRequestFailed(t *testing.T) { + require := require.New(t) + ctx := context.Background() - require.NoError(t, client.AppRequest(context.Background(), set.Of(nodeID), request, callback)) - }, - }, - { - name: "cross-chain app request", - requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { - chainID := ids.GenerateTestID() - sender.SendCrossChainAppRequestF = func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { - go func() { - require.NoError(t, network.CrossChainAppRequest(ctx, chainID, requestID, time.Time{}, request)) - }() - } - sender.SendCrossChainAppResponseF = func(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) { - go func() { - ctx = context.WithValue(ctx, ctxKey, ctxVal) - require.NoError(t, network.CrossChainAppResponse(ctx, chainID, requestID, response)) - }() - } - handler.EXPECT(). - CrossChainAppRequest(context.Background(), chainID, gomock.Any(), request). - DoAndReturn(func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { - return response, nil - }) - - callback := func(ctx context.Context, actualChainID ids.ID, actualResponse []byte, err error) { - defer wg.Done() - require.NoError(t, err) - require.Equal(t, ctxVal, ctx.Value(ctxKey)) - require.Equal(t, chainID, actualChainID) - require.Equal(t, response, actualResponse) - } + sender := FakeSender{ + SentAppRequest: make(chan []byte, 1), + } + network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") - require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) - }, - }, - { - name: "cross-chain app request failed", - requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.SendCrossChainAppRequestF = func(ctx context.Context, chainID ids.ID, requestID uint32, request []byte) { - go func() { - require.NoError(t, network.CrossChainAppRequestFailed(ctx, chainID, requestID)) - }() - } + client, err := network.NewAppProtocol(handlerID, &NoOpHandler{}) + require.NoError(err) - callback := func(_ context.Context, actualChainID ids.ID, actualResponse []byte, err error) { - defer wg.Done() + wantNodeID := ids.GenerateTestNodeID() + done := make(chan struct{}) - require.ErrorIs(t, err, ErrAppRequestFailed) - require.Equal(t, chainID, actualChainID) - require.Nil(t, actualResponse) - } + callback := func(_ context.Context, gotNodeID ids.NodeID, gotResponse []byte, err error) { + require.Equal(wantNodeID, gotNodeID) + require.ErrorIs(err, ErrAppRequestFailed) + require.Nil(gotResponse) - require.NoError(t, client.CrossChainAppRequest(context.Background(), chainID, request, callback)) - }, - }, - { - name: "app gossip", - requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.SendAppGossipF = func(ctx context.Context, gossip []byte) error { - go func() { - require.NoError(t, network.AppGossip(ctx, nodeID, gossip)) - }() + close(done) + } - return nil - } - handler.EXPECT(). - AppGossip(context.Background(), nodeID, request). - DoAndReturn(func(context.Context, ids.NodeID, []byte) error { - defer wg.Done() - return nil - }) - - require.NoError(t, client.AppGossip(context.Background(), request)) - }, - }, - { - name: "app gossip specific", - requestFunc: func(t *testing.T, network *Network, client *Client, sender *common.SenderTest, handler *mocks.MockHandler, wg *sync.WaitGroup) { - sender.SendAppGossipSpecificF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], bytes []byte) error { - for n := range nodeIDs { - nodeID := n - go func() { - require.NoError(t, network.AppGossip(ctx, nodeID, bytes)) - }() - } + require.NoError(client.AppRequest(ctx, set.Of(wantNodeID), []byte("request"), callback)) + <-sender.SentAppRequest - return nil - } - handler.EXPECT(). - AppGossip(context.Background(), nodeID, request). - DoAndReturn(func(context.Context, ids.NodeID, []byte) error { - defer wg.Done() - return nil - }) - - require.NoError(t, client.AppGossipSpecific(context.Background(), set.Of(nodeID), request)) - }, - }, + require.NoError(network.AppRequestFailed(ctx, wantNodeID, 1)) + <-done +} + +// Tests that the Client callback is called on a successful response +func TestCrossChainAppRequestResponse(t *testing.T) { + require := require.New(t) + ctx := context.Background() + + sender := FakeSender{ + SentCrossChainAppRequest: make(chan []byte, 1), } + network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require := require.New(t) - ctrl := gomock.NewController(t) + client, err := network.NewAppProtocol(handlerID, &NoOpHandler{}) + require.NoError(err) - sender := &common.SenderTest{} - handler := mocks.NewMockHandler(ctrl) - n := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") - require.NoError(n.Connected(context.Background(), nodeID, nil)) - client, err := n.NewAppProtocol(handlerID, handler) - require.NoError(err) + wantChainID := ids.GenerateTestID() + wantResponse := []byte("response") + done := make(chan struct{}) - wg := &sync.WaitGroup{} - wg.Add(1) - tt.requestFunc(t, n, client, sender, handler, wg) - wg.Wait() - }) + callback := func(_ context.Context, gotChainID ids.ID, gotResponse []byte, err error) { + require.Equal(wantChainID, gotChainID) + require.NoError(err) + require.Equal(wantResponse, gotResponse) + + close(done) } + + require.NoError(client.CrossChainAppRequest(ctx, wantChainID, []byte("request"), callback)) + <-sender.SentCrossChainAppRequest + + require.NoError(network.CrossChainAppResponse(ctx, wantChainID, 1, wantResponse)) + <-done } -func TestNetworkDropMessage(t *testing.T) { - unregistered := byte(0x0) +// Tests that the Client callback is given an error if the request fails +func TestCrossChainAppRequestFailed(t *testing.T) { + require := require.New(t) + ctx := context.Background() + + sender := FakeSender{ + SentCrossChainAppRequest: make(chan []byte, 1), + } + network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") + + client, err := network.NewAppProtocol(handlerID, &NoOpHandler{}) + require.NoError(err) + + wantChainID := ids.GenerateTestID() + done := make(chan struct{}) + callback := func(_ context.Context, gotChainID ids.ID, gotResponse []byte, err error) { + require.Equal(wantChainID, gotChainID) + require.ErrorIs(err, ErrAppRequestFailed) + require.Nil(gotResponse) + + close(done) + } + + require.NoError(client.CrossChainAppRequest(ctx, wantChainID, []byte("request"), callback)) + <-sender.SentCrossChainAppRequest + + require.NoError(network.CrossChainAppRequestFailed(ctx, wantChainID, 1)) + <-done +} + +// Messages for unregistered handlers should be dropped gracefully +func TestMessageForUnregisteredHandler(t *testing.T) { tests := []struct { - name string - requestFunc func(network *Network) error - err error + name string + msg []byte }{ { - name: "drop unregistered app request message", - requestFunc: func(network *Network) error { - return network.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{unregistered}) - }, - err: nil, - }, - { - name: "drop empty app request message", - requestFunc: func(network *Network) error { - return network.AppRequest(context.Background(), ids.GenerateTestNodeID(), 0, time.Time{}, []byte{}) - }, - err: nil, + name: "nil", + msg: nil, }, { - name: "drop unregistered cross-chain app request message", - requestFunc: func(network *Network) error { - return network.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{unregistered}) - }, - err: nil, + name: "empty", + msg: []byte{}, }, { - name: "drop empty cross-chain app request message", - requestFunc: func(network *Network) error { - return network.CrossChainAppRequest(context.Background(), ids.GenerateTestID(), 0, time.Time{}, []byte{}) - }, - err: nil, - }, - { - name: "drop unregistered gossip message", - requestFunc: func(network *Network) error { - return network.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{unregistered}) - }, - err: nil, - }, - { - name: "drop empty gossip message", - requestFunc: func(network *Network) error { - return network.AppGossip(context.Background(), ids.GenerateTestNodeID(), []byte{}) - }, - err: nil, - }, - { - name: "drop unrequested app request failed", - requestFunc: func(network *Network) error { - return network.AppRequestFailed(context.Background(), ids.GenerateTestNodeID(), 0) - }, - err: ErrUnrequestedResponse, + name: "non-empty", + msg: []byte("foobar"), }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + ctx := context.Background() + handler := &testHandler{ + appGossipF: func(context.Context, ids.NodeID, []byte) { + require.Fail("should not be called") + }, + appRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { + require.Fail("should not be called") + return nil, nil + }, + crossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { + require.Fail("should not be called") + return nil, nil + }, + } + network := NewNetwork(logging.NoLog{}, nil, prometheus.NewRegistry(), "") + _, err := network.NewAppProtocol(handlerID, handler) + require.NoError(err) + + require.Nil(network.AppRequest(ctx, ids.EmptyNodeID, 0, time.Time{}, []byte("foobar"))) + require.Nil(network.AppGossip(ctx, ids.EmptyNodeID, []byte("foobar"))) + require.Nil(network.CrossChainAppRequest(ctx, ids.Empty, 0, time.Time{}, []byte("foobar"))) + }) + } +} + +// A response or timeout for a request we never made should return an error +func TestResponseForUnrequestedRequest(t *testing.T) { + tests := []struct { + name string + msg []byte + }{ { - name: "drop unrequested app response", - requestFunc: func(network *Network) error { - return network.AppResponse(context.Background(), ids.GenerateTestNodeID(), 0, nil) - }, - err: ErrUnrequestedResponse, + name: "nil", + msg: nil, }, { - name: "drop unrequested cross-chain request failed", - requestFunc: func(network *Network) error { - return network.CrossChainAppRequestFailed(context.Background(), ids.GenerateTestID(), 0) - }, - err: ErrUnrequestedResponse, + name: "empty", + msg: []byte{}, }, { - name: "drop unrequested cross-chain response", - requestFunc: func(network *Network) error { - return network.CrossChainAppResponse(context.Background(), ids.GenerateTestID(), 0, nil) - }, - err: ErrUnrequestedResponse, + name: "non-empty", + msg: []byte("foobar"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { require := require.New(t) + ctx := context.Background() + handler := &testHandler{ + appGossipF: func(context.Context, ids.NodeID, []byte) { + require.Fail("should not be called") + }, + appRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) { + require.Fail("should not be called") + return nil, nil + }, + crossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) { + require.Fail("should not be called") + return nil, nil + }, + } + network := NewNetwork(logging.NoLog{}, nil, prometheus.NewRegistry(), "") + _, err := network.NewAppProtocol(handlerID, handler) + require.NoError(err) - network := NewNetwork(logging.NoLog{}, &common.SenderTest{}, prometheus.NewRegistry(), "") + require.ErrorIs(ErrUnrequestedResponse, network.AppResponse(ctx, ids.EmptyNodeID, 0, []byte("foobar"))) + require.ErrorIs(ErrUnrequestedResponse, network.AppRequestFailed(ctx, ids.EmptyNodeID, 0)) - err := tt.requestFunc(network) - require.ErrorIs(err, tt.err) + require.ErrorIs(ErrUnrequestedResponse, network.CrossChainAppResponse(ctx, ids.Empty, 0, []byte("foobar"))) + require.ErrorIs(ErrUnrequestedResponse, network.CrossChainAppRequestFailed(ctx, ids.Empty, 0)) }) } } @@ -317,58 +250,25 @@ func TestNetworkDropMessage(t *testing.T) { // not attempt to issue another request until the previous one has cleared. func TestAppRequestDuplicateRequestIDs(t *testing.T) { require := require.New(t) - ctrl := gomock.NewController(t) + ctx := context.Background() - handler := mocks.NewMockHandler(ctrl) - sender := &common.SenderTest{ - SendAppResponseF: func(context.Context, ids.NodeID, uint32, []byte) error { - return nil - }, - } - network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") - nodeID := ids.GenerateTestNodeID() - - requestSent := &sync.WaitGroup{} - sender.SendAppRequestF = func(ctx context.Context, nodeIDs set.Set[ids.NodeID], requestID uint32, request []byte) error { - for range nodeIDs { - requestSent.Add(1) - go func() { - require.NoError(network.AppRequest(ctx, nodeID, requestID, time.Time{}, request)) - requestSent.Done() - }() - } - - return nil + sender := &FakeSender{ + SentAppRequest: make(chan []byte, 1), } - timeout := &sync.WaitGroup{} - response := []byte("response") - handler.EXPECT().AppRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, nodeID ids.NodeID, deadline time.Time, request []byte) ([]byte, error) { - timeout.Wait() - return response, nil - }).AnyTimes() - - require.NoError(network.Connected(context.Background(), nodeID, nil)) - client, err := network.NewAppProtocol(0x1, handler) + network := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "") + client, err := network.NewAppProtocol(0x1, &NoOpHandler{}) require.NoError(err) - onResponse := func(ctx context.Context, nodeID ids.NodeID, got []byte, err error) { - require.NoError(err) - require.Equal(response, got) - } - - require.NoError(client.AppRequest(context.Background(), set.Of(nodeID), []byte{}, onResponse)) - requestSent.Wait() + noOpCallback := func(context.Context, ids.NodeID, []byte, error) {} + // create a request that never gets a response + require.NoError(client.AppRequest(ctx, set.Of(ids.EmptyNodeID), []byte{}, noOpCallback)) + <-sender.SentAppRequest // force the network to use the same requestID network.router.requestID = 1 - timeout.Add(1) - err = client.AppRequest(context.Background(), set.Of(nodeID), []byte{}, nil) - requestSent.Wait() + err = client.AppRequest(context.Background(), set.Of(ids.EmptyNodeID), []byte{}, noOpCallback) require.ErrorIs(err, ErrRequestPending) - - timeout.Done() } // Sample should always return up to [limit] peers, and less if fewer than @@ -437,7 +337,7 @@ func TestPeersSample(t *testing.T) { t.Run(tt.name, func(t *testing.T) { require := require.New(t) - network := NewNetwork(logging.NoLog{}, &common.SenderTest{}, prometheus.NewRegistry(), "") + network := NewNetwork(logging.NoLog{}, &FakeSender{}, prometheus.NewRegistry(), "") for connected := range tt.connected { require.NoError(network.Connected(context.Background(), connected, nil)) @@ -479,11 +379,9 @@ func TestAppRequestAnyNodeSelection(t *testing.T) { require := require.New(t) sent := set.Set[ids.NodeID]{} - sender := &common.SenderTest{ - SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error { - for nodeID := range nodeIDs { - sent.Add(nodeID) - } + sender := &MockSender{ + SendAppRequestF: func(_ context.Context, nodeID ids.NodeID, _ uint32, _ []byte) error { + sent.Add(nodeID) return nil }, } @@ -498,6 +396,7 @@ func TestAppRequestAnyNodeSelection(t *testing.T) { err = client.AppRequestAny(context.Background(), []byte("foobar"), nil) require.ErrorIs(err, tt.expected) + require.Subset(tt.peers, sent.List()) }) } } @@ -569,9 +468,9 @@ func TestNodeSamplerClientOption(t *testing.T) { require := require.New(t) done := make(chan struct{}) - sender := &common.SenderTest{ - SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error { - require.Subset(tt.expected, nodeIDs.List()) + sender := &MockSender{ + SendAppRequestF: func(_ context.Context, nodeID ids.NodeID, _ uint32, _ []byte) error { + require.Subset(tt.expected, []ids.NodeID{nodeID}) close(done) return nil }, diff --git a/network/p2p/router.go b/network/p2p/router.go index 110e9b6de62..773060e4b4b 100644 --- a/network/p2p/router.go +++ b/network/p2p/router.go @@ -55,12 +55,17 @@ type meteredHandler struct { *metrics } +type routerSender interface { + AppResponseSender + CrossChainAppResponseSender +} + // router routes incoming application messages to the corresponding registered // app handler. App messages must be made using the registered handler's // corresponding Client. type router struct { log logging.Logger - sender common.AppSender + sender routerSender metrics prometheus.Registerer namespace string @@ -74,7 +79,7 @@ type router struct { // newRouter returns a new instance of Router func newRouter( log logging.Logger, - sender common.AppSender, + sender routerSender, metrics prometheus.Registerer, namespace string, ) *router { diff --git a/network/p2p/sender.go b/network/p2p/sender.go new file mode 100644 index 00000000000..8db23c2f54b --- /dev/null +++ b/network/p2p/sender.go @@ -0,0 +1,194 @@ +package p2p + +import ( + "context" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/set" +) + +var ( + _ AppSender = (*Sender)(nil) + _ AppSender = (*FakeSender)(nil) +) + +type AppRequestSender interface { + SendAppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, bytes []byte) error +} + +type AppResponseSender interface { + SendAppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, bytes []byte) error +} + +type AppGossipSender interface { + SendAppGossip(ctx context.Context, bytes []byte) error + SendAppGossipSpecific(ctx context.Context, nodeIDs set.Set[ids.NodeID], bytes []byte) error +} + +type CrossChainAppRequestSender interface { + SendCrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, bytes []byte) error +} + +type CrossChainAppResponseSender interface { + SendCrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, bytes []byte) error +} + +type AppSender interface { + AppRequestSender + AppResponseSender + AppGossipSender + CrossChainAppRequestSender + CrossChainAppResponseSender +} + +func NewSender(sender common.AppSender) *Sender { + return &Sender{ + sender: sender, + } +} + +type Sender struct { + sender common.AppSender +} + +func (s Sender) SendAppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, bytes []byte) error { + return s.sender.SendAppRequest(ctx, set.Of(nodeID), requestID, bytes) +} + +func (s Sender) SendAppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, bytes []byte) error { + return s.sender.SendAppResponse(ctx, nodeID, requestID, bytes) +} + +func (s Sender) SendAppGossip(ctx context.Context, bytes []byte) error { + return s.sender.SendAppGossip(ctx, bytes) +} + +func (s Sender) SendAppGossipSpecific(ctx context.Context, nodeIDs set.Set[ids.NodeID], bytes []byte) error { + return s.sender.SendAppGossipSpecific(ctx, nodeIDs, bytes) +} + +func (s Sender) SendCrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, bytes []byte) error { + return s.sender.SendCrossChainAppRequest(ctx, chainID, requestID, bytes) +} + +func (s Sender) SendCrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, bytes []byte) error { + return s.sender.SendCrossChainAppResponse(ctx, chainID, requestID, bytes) +} + +type FakeSender struct { + SentAppRequest, SentAppResponse, + SentAppGossip, SentAppGossipSpecific, + SentCrossChainAppRequest, SentCrossChainAppResponse chan []byte +} + +func (f FakeSender) SendAppRequest(_ context.Context, _ ids.NodeID, _ uint32, bytes []byte) error { + if f.SentAppRequest == nil { + return nil + } + + f.SentAppRequest <- bytes + return nil +} + +func (f FakeSender) SendAppResponse(_ context.Context, _ ids.NodeID, _ uint32, bytes []byte) error { + if f.SentAppResponse == nil { + return nil + } + + f.SentAppResponse <- bytes + return nil +} + +func (f FakeSender) SendAppGossip(_ context.Context, bytes []byte) error { + if f.SentAppGossip == nil { + return nil + } + + f.SentAppGossip <- bytes + return nil +} + +func (f FakeSender) SendAppGossipSpecific(_ context.Context, _ set.Set[ids.NodeID], bytes []byte) error { + if f.SentAppGossipSpecific == nil { + return nil + } + + f.SentAppGossipSpecific <- bytes + return nil +} + +func (f FakeSender) SendCrossChainAppRequest(_ context.Context, _ ids.ID, _ uint32, bytes []byte) error { + if f.SentCrossChainAppRequest == nil { + return nil + } + + f.SentCrossChainAppRequest <- bytes + return nil +} + +func (f FakeSender) SendCrossChainAppResponse(_ context.Context, _ ids.ID, _ uint32, bytes []byte) error { + if f.SentCrossChainAppResponse == nil { + return nil + } + + f.SentCrossChainAppResponse <- bytes + return nil +} + +type MockSender struct { + SendAppRequestF func(context.Context, ids.NodeID, uint32, []byte) error + SendAppResponseF func(context.Context, ids.NodeID, uint32, []byte) error + SendAppGossipF func(context.Context, []byte) error + SendAppGossipSpecificF func(context.Context, set.Set[ids.NodeID], []byte) error + SendCrossChainAppRequestF func(context.Context, ids.ID, uint32, []byte) error + SendCrossChainAppResponseF func(context.Context, ids.ID, uint32, []byte) error +} + +func (f MockSender) SendAppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, bytes []byte) error { + if f.SendAppRequestF == nil { + return nil + } + + return f.SendAppRequestF(ctx, nodeID, requestID, bytes) +} + +func (f MockSender) SendAppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, bytes []byte) error { + if f.SendAppResponseF == nil { + return nil + } + + return f.SendAppResponseF(ctx, nodeID, requestID, bytes) +} + +func (f MockSender) SendAppGossip(ctx context.Context, bytes []byte) error { + if f.SendAppGossipF == nil { + return nil + } + + return f.SendAppGossipF(ctx, bytes) +} + +func (f MockSender) SendAppGossipSpecific(ctx context.Context, nodeIDs set.Set[ids.NodeID], bytes []byte) error { + if f.SendAppGossipSpecificF == nil { + return nil + } + + return f.SendAppGossipSpecificF(ctx, nodeIDs, bytes) +} + +func (f MockSender) SendCrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, bytes []byte) error { + if f.SendCrossChainAppRequestF == nil { + return nil + } + + return f.SendCrossChainAppRequestF(ctx, chainID, requestID, bytes) +} + +func (f MockSender) SendCrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, bytes []byte) error { + if f.SendCrossChainAppResponseF == nil { + return nil + } + + return f.SendCrossChainAppResponseF(ctx, chainID, requestID, bytes) +} diff --git a/network/p2p/validators_test.go b/network/p2p/validators_test.go index e721b4a978a..325e7e721d2 100644 --- a/network/p2p/validators_test.go +++ b/network/p2p/validators_test.go @@ -16,7 +16,6 @@ import ( "go.uber.org/mock/gomock" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/utils/logging" ) @@ -179,7 +178,7 @@ func TestValidatorsSample(t *testing.T) { } gomock.InOrder(calls...) - network := NewNetwork(logging.NoLog{}, &common.SenderTest{}, prometheus.NewRegistry(), "") + network := NewNetwork(logging.NoLog{}, &FakeSender{}, prometheus.NewRegistry(), "") ctx := context.Background() require.NoError(network.Connected(ctx, nodeID1, nil)) require.NoError(network.Connected(ctx, nodeID2, nil)) diff --git a/scripts/mocks.mockgen.txt b/scripts/mocks.mockgen.txt index 76add90f3f7..21df924dd28 100644 --- a/scripts/mocks.mockgen.txt +++ b/scripts/mocks.mockgen.txt @@ -6,7 +6,6 @@ github.com/ava-labs/avalanchego/database=Iterator=database/mock_iterator.go github.com/ava-labs/avalanchego/message=OutboundMessage=message/mock_message.go github.com/ava-labs/avalanchego/message=OutboundMsgBuilder=message/mock_outbound_message_builder.go github.com/ava-labs/avalanchego/network/peer=GossipTracker=network/peer/mock_gossip_tracker.go -github.com/ava-labs/avalanchego/network/p2p=Handler=network/p2p/mocks/mock_handler.go github.com/ava-labs/avalanchego/snow/consensus/snowman=Block=snow/consensus/snowman/mock_block.go github.com/ava-labs/avalanchego/snow/engine/avalanche/vertex=LinearizableVM=snow/engine/avalanche/vertex/mock_vm.go github.com/ava-labs/avalanchego/snow/engine/snowman/block=BuildBlockWithContextChainVM=snow/engine/snowman/block/mocks/build_block_with_context_vm.go