diff --git a/network/p2p/p2p.go b/network/p2p/p2p.go index 51485ce78b..de420a0837 100644 --- a/network/p2p/p2p.go +++ b/network/p2p/p2p.go @@ -195,6 +195,9 @@ type connLimitConfig struct { func deriveConnLimits(cfg config.Local) connLimitConfig { var low, high, rcmgrConns, rcmgrConnsInbound, rcmgrConnsOutbound int rcmgrConnsOutbound = cfg.GossipFanout * 3 + if cfg.EnableDHTProviders { + rcmgrConnsOutbound += cfg.GossipFanout * 3 + } if cfg.IsListenServer() { if cfg.IncomingConnectionsLimit < 0 { rcmgrConns = math.MaxInt @@ -340,8 +343,18 @@ func (s *serviceImpl) DialPeersUntilTargetCount(targetConnCount int) bool { if numOutgoingConns >= targetConnCount { return numOutgoingConns > preExistingConns } - // if we are already connected to this peer, skip it - if len(s.host.Network().ConnsToPeer(peerInfo.ID)) > 0 { + // if we are already connected to this peer, ensure it's properly handled + if conns := s.host.Network().ConnsToPeer(peerInfo.ID); len(conns) > 0 { + if !s.host.ConnManager().IsProtected(peerInfo.ID, cnmgrTag) { + // connection was established by DHT/pubsub before the mesh thread + // could protect it, so handleConnected skipped stream creation. + // protect and re-trigger stream setup now. 
+ s.host.ConnManager().Protect(peerInfo.ID, cnmgrTag) + go s.streams.handleConnected(conns[0]) + if conns[0].Stat().Direction == network.DirOutbound { + numOutgoingConns++ + } + } continue } err := s.dialNode(context.Background(), peerInfo) // leaving the calls as blocking for now, to not over-connect beyond fanout diff --git a/network/p2p/p2p_test.go b/network/p2p/p2p_test.go index 035aebc90d..9f15616b57 100644 --- a/network/p2p/p2p_test.go +++ b/network/p2p/p2p_test.go @@ -359,6 +359,23 @@ func TestDeriveConnLimits_UnboundedServer(t *testing.T) { require.Equal(t, 0, limits.connMgrLow) } +func TestDeriveConnLimits_DHTProviders(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + cfg := config.GetDefaultLocal() + cfg.NetAddress = ":4160" + cfg.IncomingConnectionsLimit = 2400 + cfg.GossipFanout = 4 + cfg.EnableDHTProviders = true + limits := deriveConnLimits(cfg) + require.Equal(t, 2400+12+12, limits.rcmgrConns) + require.Equal(t, 2400, limits.rcmgrConnsInbound) + require.Equal(t, 24, limits.rcmgrConnsOutbound) + require.Equal(t, 2424, limits.connMgrHigh) + require.Equal(t, 2327, limits.connMgrLow) // 2424 * 96 / 100 +} + func TestDeriveConnLimits_Client(t *testing.T) { partitiontest.PartitionTest(t) t.Parallel() diff --git a/network/p2p/streams.go b/network/p2p/streams.go index cfae8cf679..97282c7ac8 100644 --- a/network/p2p/streams.go +++ b/network/p2p/streams.go @@ -98,7 +98,8 @@ func (n *streamManager) streamHandler(stream network.Stream) { // an error occurred while checking the old stream n.log.Infof("Failed to check old stream with %s: %v", remotePeer, err) } - n.streams[stream.Conn().RemotePeer()] = stream + // old stream is dead, remove + delete(n.streams, remotePeer) incoming := stream.Conn().Stat().Direction == network.DirInbound if err1 := n.dispatch(n.ctx, remotePeer, stream, incoming); err1 != nil { @@ -106,6 +107,7 @@ func (n *streamManager) streamHandler(stream network.Stream) { _ = stream.Reset() return } + 
n.streams[stream.Conn().RemotePeer()] = stream dispatched = true return } @@ -115,7 +117,6 @@ func (n *streamManager) streamHandler(stream network.Stream) { return } // no old stream - n.streams[stream.Conn().RemotePeer()] = stream incoming := stream.Conn().Stat().Direction == network.DirInbound if err := n.dispatch(n.ctx, remotePeer, stream, incoming); err != nil { n.log.Errorln(err.Error()) @@ -123,6 +124,8 @@ func (n *streamManager) streamHandler(stream network.Stream) { return } + n.streams[stream.Conn().RemotePeer()] = stream + dispatched = true } @@ -143,21 +146,12 @@ func (n *streamManager) dispatch(ctx context.Context, remotePeer peer.ID, stream // We do some read/write operations in this handler for metadata exchange that creates a race condition // with StopNotify on network shutdown. To avoid, run the handler as a goroutine. func (n *streamManager) Connected(net network.Network, conn network.Conn) { - go n.handleConnected(conn) -} - -func (n *streamManager) handleConnected(conn network.Conn) { - dispatched := false - defer func() { - if !dispatched { - n.host.ConnManager().Unprotect(conn.RemotePeer(), cnmgrTag) - } - }() remotePeer := conn.RemotePeer() localPeer := n.host.ID() if conn.Stat().Direction == network.DirInbound && !n.allowIncomingGossip { n.log.Debugf("%s: ignoring incoming connection from %s", localPeer.String(), remotePeer.String()) + n.host.ConnManager().Unprotect(conn.RemotePeer(), cnmgrTag) return } @@ -166,7 +160,6 @@ func (n *streamManager) handleConnected(conn network.Conn) { // so mark dispatched to preserve the cnmgr protection set by dialNode. 
if localPeer > remotePeer { n.log.Debugf("%s: ignoring a lesser peer ID %s", localPeer.String(), remotePeer.String()) - dispatched = true return } @@ -179,10 +172,23 @@ func (n *streamManager) handleConnected(conn network.Conn) { } } + go n.handleConnected(conn) +} + +func (n *streamManager) handleConnected(conn network.Conn) { + dispatched := false + defer func() { + if !dispatched { + n.host.ConnManager().Unprotect(conn.RemotePeer(), cnmgrTag) + } + }() + remotePeer := conn.RemotePeer() + localPeer := n.host.ID() + n.streamsLock.Lock() _, ok := n.streams[remotePeer] + n.streamsLock.Unlock() if ok { - n.streamsLock.Unlock() n.log.Debugf("%s: already have a stream to/from %s", localPeer.String(), remotePeer.String()) dispatched = true return // there's already an active stream with this peer for our protocol @@ -195,12 +201,8 @@ func (n *streamManager) handleConnected(conn network.Conn) { stream, err := n.host.NewStream(n.ctx, remotePeer, protos...) if err != nil { n.log.Infof("%s: failed to open stream to %s (%s): %v", localPeer.String(), remotePeer, conn.RemoteMultiaddr().String(), err) - n.streamsLock.Unlock() return } - n.streams[remotePeer] = stream - n.streamsLock.Unlock() - n.log.Infof("%s: using protocol %s with peer %s", localPeer.String(), stream.Protocol(), remotePeer.String()) incoming := stream.Conn().Stat().Direction == network.DirInbound @@ -210,6 +212,20 @@ func (n *streamManager) handleConnected(conn network.Conn) { return } + n.streamsLock.Lock() + defer n.streamsLock.Unlock() + if _, exists := n.streams[remotePeer]; exists { + // another stream was added in the meantime, close this one and keep the existing one + _ = stream.Reset() + dispatched = true + return + } + // don't add disconnected / died conns, so Disconnect won't need to clean up + if stream.Conn().IsClosed() { + _ = stream.Reset() + return // dispatched is still false + } + n.streams[remotePeer] = stream dispatched = true } diff --git a/network/p2p/streams_stale_test.go 
b/network/p2p/streams_stale_test.go new file mode 100644 index 0000000000..84709e8c7c --- /dev/null +++ b/network/p2p/streams_stale_test.go @@ -0,0 +1,430 @@ +// Copyright (C) 2019-2026 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see <https://www.gnu.org/licenses/>. + +package p2p + +import ( + "context" + "errors" + "testing" + "time" + + connmgrcore "github.com/libp2p/go-libp2p/core/connmgr" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" + + "github.com/algorand/go-deadlock" + + "github.com/algorand/go-algorand/logging" + "github.com/algorand/go-algorand/test/partitiontest" +) + +// errDispatchFailed is the sentinel error returned by the failing test handler. +var errDispatchFailed = errors.New("dispatch failed") + +const testProto = protocol.ID("/algorand-test/1.0.0") + +// mockConnMgr implements connmgrcore.ConnManager for testing.
+type mockConnMgr struct { + mu deadlock.Mutex + protected map[peer.ID]map[string]bool +} + +func newMockConnMgr() *mockConnMgr { + return &mockConnMgr{protected: make(map[peer.ID]map[string]bool)} +} + +func (m *mockConnMgr) Protect(id peer.ID, tag string) { + m.mu.Lock() + defer m.mu.Unlock() + if m.protected[id] == nil { + m.protected[id] = make(map[string]bool) + } + m.protected[id][tag] = true +} + +func (m *mockConnMgr) Unprotect(id peer.ID, tag string) bool { + m.mu.Lock() + defer m.mu.Unlock() + if m.protected[id] != nil { + delete(m.protected[id], tag) + } + return len(m.protected[id]) > 0 +} + +func (m *mockConnMgr) IsProtected(id peer.ID, tag string) bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.protected[id] != nil && m.protected[id][tag] +} + +func (m *mockConnMgr) TagPeer(peer.ID, string, int) {} +func (m *mockConnMgr) UntagPeer(peer.ID, string) {} +func (m *mockConnMgr) UpsertTag(peer.ID, string, func(int) int) {} +func (m *mockConnMgr) GetTagInfo(peer.ID) *connmgrcore.TagInfo { return nil } +func (m *mockConnMgr) TrimOpenConns(context.Context) {} +func (m *mockConnMgr) Notifee() network.Notifiee { return nil } +func (m *mockConnMgr) CheckLimit(connmgrcore.GetConnLimiter) error { return nil } +func (m *mockConnMgr) Close() error { return nil } + +// mockHost implements host.Host with only the methods used by streamManager. +type mockHost struct { + id peer.ID + cm *mockConnMgr + newStreamFn func(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) +} + +func (h *mockHost) ID() peer.ID { return h.id } +func (h *mockHost) ConnManager() connmgrcore.ConnManager { return h.cm } +func (h *mockHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (network.Stream, error) { + return h.newStreamFn(ctx, p, pids...) 
+} + +func (h *mockHost) Peerstore() peerstore.Peerstore { panic("unused") } +func (h *mockHost) Addrs() []ma.Multiaddr { panic("unused") } +func (h *mockHost) Network() network.Network { panic("unused") } +func (h *mockHost) Mux() protocol.Switch { panic("unused") } +func (h *mockHost) Connect(context.Context, peer.AddrInfo) error { panic("unused") } +func (h *mockHost) SetStreamHandler(protocol.ID, network.StreamHandler) {} +func (h *mockHost) SetStreamHandlerMatch(protocol.ID, func(protocol.ID) bool, network.StreamHandler) { +} +func (h *mockHost) RemoveStreamHandler(protocol.ID) {} +func (h *mockHost) Close() error { return nil } +func (h *mockHost) EventBus() event.Bus { panic("unused") } + +// Verify interface satisfaction at compile time. +var _ host.Host = (*mockHost)(nil) + +// mockConn implements network.Conn with controllable direction and peer IDs. +type mockConn struct { + remotePeerID peer.ID + localPeerID peer.ID + dir network.Direction +} + +func newMockConn(local, remote peer.ID, dir network.Direction) *mockConn { + return &mockConn{localPeerID: local, remotePeerID: remote, dir: dir} +} + +func (c *mockConn) Close() error { return nil } +func (c *mockConn) LocalPeer() peer.ID { return c.localPeerID } +func (c *mockConn) RemotePeer() peer.ID { return c.remotePeerID } +func (c *mockConn) RemotePublicKey() ic.PubKey { return nil } +func (c *mockConn) ConnState() network.ConnectionState { return network.ConnectionState{} } +func (c *mockConn) LocalMultiaddr() ma.Multiaddr { return ma.StringCast("/ip4/127.0.0.1/tcp/4190") } +func (c *mockConn) RemoteMultiaddr() ma.Multiaddr { return ma.StringCast("/ip4/1.2.3.4/tcp/4190") } +func (c *mockConn) Stat() network.ConnStats { + return network.ConnStats{Stats: network.Stats{Direction: c.dir}} +} +func (c *mockConn) Scope() network.ConnScope { return nil } +func (c *mockConn) ID() string { return "mock-conn" } +func (c *mockConn) NewStream(context.Context) (network.Stream, error) { panic("unused") } +func (c 
*mockConn) GetStreams() []network.Stream { return nil } +func (c *mockConn) IsClosed() bool { return false } + +var _ network.Conn = (*mockConn)(nil) + +// mockStream implements network.Stream with controllable behavior. +type mockStream struct { + mu deadlock.Mutex + conn *mockConn + proto protocol.ID + dir network.Direction + readErr error // error returned by Read + resetCalled bool + closeCalled bool +} + +func newMockStream(conn *mockConn, proto protocol.ID, dir network.Direction) *mockStream { + return &mockStream{conn: conn, proto: proto, dir: dir} +} + +func (s *mockStream) Read(p []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.readErr != nil { + return 0, s.readErr + } + return 0, nil +} + +func (s *mockStream) Write(p []byte) (int, error) { return len(p), nil } + +func (s *mockStream) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closeCalled = true + return nil +} + +func (s *mockStream) CloseRead() error { return nil } +func (s *mockStream) CloseWrite() error { return nil } + +func (s *mockStream) Reset() error { + s.mu.Lock() + defer s.mu.Unlock() + s.resetCalled = true + return nil +} + +func (s *mockStream) SetDeadline(time.Time) error { return nil } +func (s *mockStream) SetReadDeadline(time.Time) error { return nil } +func (s *mockStream) SetWriteDeadline(time.Time) error { return nil } +func (s *mockStream) Protocol() protocol.ID { return s.proto } +func (s *mockStream) SetProtocol(protocol.ID) error { return nil } +func (s *mockStream) Stat() network.Stats { return network.Stats{Direction: s.dir} } +func (s *mockStream) Conn() network.Conn { return s.conn } +func (s *mockStream) ID() string { return "mock-stream" } +func (s *mockStream) Scope() network.StreamScope { return nil } + +func (s *mockStream) wasReset() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.resetCalled +} + +var _ network.Stream = (*mockStream)(nil) + +// failingHandler is a StreamHandler that always returns errDispatchFailed. 
+func failingHandler(_ context.Context, _ peer.ID, _ network.Stream, _ bool) error { + return errDispatchFailed +} + +// newTestStreamManager creates a streamManager with a failing handler for testProto. +func newTestStreamManager(localID peer.ID, allowIncoming bool) (*streamManager, *mockHost) { + cm := newMockConnMgr() + h := &mockHost{id: localID, cm: cm} + handlers := StreamHandlers{ + {ProtoID: testProto, Handler: failingHandler}, + } + logger := logging.NewLogger() + logger.SetLevel(logging.Debug) + sm := makeStreamManager(context.Background(), logger, h, handlers, allowIncoming) + return sm, h +} + +// assertStreamMapEmpty checks that sm.streams has no entry for remotePeer. +func assertStreamMapEmpty(t *testing.T, sm *streamManager, remotePeer peer.ID) { + t.Helper() + sm.streamsLock.Lock() + defer sm.streamsLock.Unlock() + _, exists := sm.streams[remotePeer] + require.False(t, exists, "expected n.streams[%s] to be cleaned up after dispatch failure", remotePeer) +} + +// --- test cases --- + +// TestStream_MapCleanupOnDispatchFailure verifies that n.streams is cleaned up +// when dispatch (V22 handshake) fails, across all 8 combinations: +// +// directions (inbound/outbound) × +// peer ID orderings (local < remote / local > remote) × +// dial origins (dialNode / DHT-pubsub) +func TestStream_MapCleanupOnDispatchFailure(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + // deterministic peer IDs + lowPeer := peer.ID("AAAA-low-peer") + highPeer := peer.ID("ZZZZ-high-peer") + require.True(t, lowPeer < highPeer) + + // handleConnected path — local node initiates the stream + + // case 1: outbound, localPeer < remotePeer, dialNode + // Connected, peer ID passes, protected, handleConnected, dispatch fails + t.Run("outbound_localLow_dialNode", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(lowPeer, true) + conn := newMockConn(lowPeer, highPeer, network.DirOutbound) + stream := newMockStream(conn, testProto, network.DirOutbound) 
+ h.newStreamFn = func(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) { + return stream, nil + } + // dialNode would have protected the peer before dialing + h.cm.Protect(highPeer, cnmgrTag) + + sm.handleConnected(conn) + + assertStreamMapEmpty(t, sm, highPeer) + require.True(t, stream.wasReset()) + require.False(t, h.cm.IsProtected(highPeer, cnmgrTag)) + }) + + // case 2: outbound, localPeer < remotePeer, DHT dial + // Connected skips (unprotected). DialPeersUntilTargetCount, Protect, handleConnected, dispatch fails + t.Run("outbound_localLow_dhtDial", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(lowPeer, true) + conn := newMockConn(lowPeer, highPeer, network.DirOutbound) + stream := newMockStream(conn, testProto, network.DirOutbound) + h.newStreamFn = func(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) { + return stream, nil + } + // DialPeersUntilTargetCount protects then calls handleConnected + h.cm.Protect(highPeer, cnmgrTag) + + sm.handleConnected(conn) + + assertStreamMapEmpty(t, sm, highPeer) + require.True(t, stream.wasReset()) + require.False(t, h.cm.IsProtected(highPeer, cnmgrTag)) + }) + + // case 3: outbound, localPeer > remotePeer, DHT dial + // Connected defers (peer ID). Remote's stream rejected (unprotected outbound). 
+ // DialPeersUntilTargetCount, Protect, handleConnected (new code skips peer ID check), dispatch fails + t.Run("outbound_localHigh_dhtDial", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(highPeer, true) + conn := newMockConn(highPeer, lowPeer, network.DirOutbound) + stream := newMockStream(conn, testProto, network.DirOutbound) + h.newStreamFn = func(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) { + return stream, nil + } + h.cm.Protect(lowPeer, cnmgrTag) + + sm.handleConnected(conn) + + assertStreamMapEmpty(t, sm, lowPeer) + require.True(t, stream.wasReset()) + require.False(t, h.cm.IsProtected(lowPeer, cnmgrTag)) + }) + + // case 4: inbound, localPeer < remotePeer, remote's dialNode dialed + // Connected, inbound gossip OK, peer ID passes, handleConnected, dispatch fails + t.Run("inbound_localLow_remoteDial", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(lowPeer, true) + conn := newMockConn(lowPeer, highPeer, network.DirInbound) + stream := newMockStream(conn, testProto, network.DirInbound) + h.newStreamFn = func(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) { + return stream, nil + } + + sm.handleConnected(conn) + + assertStreamMapEmpty(t, sm, highPeer) + require.True(t, stream.wasReset()) + }) + + // case 5: inbound, localPeer < remotePeer, remote's DHT dialed + t.Run("inbound_localLow_remoteDHT", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(lowPeer, true) + conn := newMockConn(lowPeer, highPeer, network.DirInbound) + stream := newMockStream(conn, testProto, network.DirInbound) + h.newStreamFn = func(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) { + return stream, nil + } + // No protection from our side (remote's inbound conn) + + sm.handleConnected(conn) + + assertStreamMapEmpty(t, sm, highPeer) + require.True(t, stream.wasReset()) + // Unprotect was called but was a no-op (nothing was protected) + require.False(t, 
h.cm.IsProtected(highPeer, cnmgrTag)) + }) + + // streamHandler path — remote peer creates the stream, our node handles it + + // case 6: outbound, localPeer > remotePeer, our dialNode + // Connected defers (peer ID). Remote opens stream, our streamHandler, dispatch fails + t.Run("outbound_localHigh_ourDial", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(highPeer, true) + // Connection is outbound (we dialed), stream is inbound (remote initiated) + conn := newMockConn(highPeer, lowPeer, network.DirOutbound) + stream := newMockStream(conn, testProto, network.DirInbound) + // Our dialNode protected this peer + h.cm.Protect(lowPeer, cnmgrTag) + + sm.streamHandler(stream) + + assertStreamMapEmpty(t, sm, lowPeer) + require.True(t, stream.wasReset()) + // dispatched=false => Unprotect called + require.False(t, h.cm.IsProtected(lowPeer, cnmgrTag)) + }) + + // case 7: inbound, localPeer > remotePeer, remote's dialNode dialed us + // Connected defers (peer ID). Remote opens stream, our streamHandler, dispatch fails + t.Run("inbound_localHigh_remoteDial", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(highPeer, true) + // Connection is inbound (remote dialed us), stream is inbound (remote initiated) + conn := newMockConn(highPeer, lowPeer, network.DirInbound) + stream := newMockStream(conn, testProto, network.DirInbound) + + sm.streamHandler(stream) + + assertStreamMapEmpty(t, sm, lowPeer) + require.True(t, stream.wasReset()) + // No protection was set, so Unprotect is a no-op + require.False(t, h.cm.IsProtected(lowPeer, cnmgrTag)) + }) + + // case 8: inbound, localPeer > remotePeer, remote's DHT dialed us + // Connected defers (peer ID). 
Remote's DHT connection; + // remote opens stream, our streamHandler, dispatch fails + t.Run("inbound_localHigh_remoteDHT", func(t *testing.T) { + t.Parallel() + sm, h := newTestStreamManager(highPeer, true) + conn := newMockConn(highPeer, lowPeer, network.DirInbound) + stream := newMockStream(conn, testProto, network.DirInbound) + + sm.streamHandler(stream) + + assertStreamMapEmpty(t, sm, lowPeer) + require.True(t, stream.wasReset()) + require.False(t, h.cm.IsProtected(lowPeer, cnmgrTag)) + }) +} + +// TestStream_HandlerCleanupReplacingDeadStream verifies that when streamHandler +// replaces a dead stream and the new dispatch also fails, the map entry is cleaned up. +func TestStream_HandlerCleanupReplacingDeadStream(t *testing.T) { + partitiontest.PartitionTest(t) + t.Parallel() + + localID := peer.ID("ZZZZ-high-peer") + remoteID := peer.ID("AAAA-low-peer") + sm, h := newTestStreamManager(localID, true) + + // Pre-populate n.streams with a dead (reset) stream + conn := newMockConn(localID, remoteID, network.DirInbound) + deadStream := newMockStream(conn, testProto, network.DirInbound) + deadStream.readErr = network.ErrReset // Read returns error => stream is dead + sm.streams[remoteID] = deadStream + + // Protect so that Unprotect tracking works + h.cm.Protect(remoteID, cnmgrTag) + + // New stream arrives from remote peer, dispatch will fail + newStream := newMockStream(conn, testProto, network.DirInbound) + sm.streamHandler(newStream) + + assertStreamMapEmpty(t, sm, remoteID) + require.True(t, newStream.wasReset(), "new stream should be reset on dispatch failure") +} diff --git a/network/p2pMetainfo.go b/network/p2pMetainfo.go index 39497bbaed..10f11faa8d 100644 --- a/network/p2pMetainfo.go +++ b/network/p2pMetainfo.go @@ -70,17 +70,17 @@ type peerMetaInfo struct { func readPeerMetaHeaders(stream io.ReadWriter, p2pPeer peer.ID, netProtoSupportedVersions []string) (peerMetaInfo, error) { var msgLenBytes [2]byte - rn, err := stream.Read(msgLenBytes[:]) - if rn != 2 
|| err != nil { + _, err := io.ReadFull(stream, msgLenBytes[:]) + if err != nil { err0 := fmt.Errorf("error reading response message length from peer %s: %w", p2pPeer, err) return peerMetaInfo{}, err0 } msgLen := binary.BigEndian.Uint16(msgLenBytes[:]) msgBytes := make([]byte, msgLen) - rn, err = stream.Read(msgBytes[:]) - if rn != int(msgLen) || err != nil { - err0 := fmt.Errorf("error reading response message from peer %s: %w, expected: %d, read: %d", p2pPeer, err, msgLen, rn) + _, err = io.ReadFull(stream, msgBytes) + if err != nil { + err0 := fmt.Errorf("error reading response message from peer %s: %w", p2pPeer, err) return peerMetaInfo{}, err0 } var responseHeaders peerMetaHeaders diff --git a/network/p2pMetainfo_test.go b/network/p2pMetainfo_test.go index f5b9b89bfa..50aaf0f89a 100644 --- a/network/p2pMetainfo_test.go +++ b/network/p2pMetainfo_test.go @@ -81,9 +81,10 @@ func TestReadPeerMetaHeaders(t *testing.T) { assert.Equal(t, "mockFeatures", metaInfo.features) mockStream.AssertExpectations(t) - // Error case: incomplete length read + // Error case: incomplete length read then EOF mockStream = new(MockStream) mockStream.On("Read", mock.Anything).Return([]byte{1}, nil).Once() + mockStream.On("Read", mock.Anything).Return([]byte{}, fmt.Errorf("EOF")).Once() _, err = readPeerMetaHeaders(mockStream, p2pPeer, n.supportedProtocolVersions) assert.ErrorContains(t, err, "error reading response message length") mockStream.AssertExpectations(t) @@ -95,10 +96,11 @@ func TestReadPeerMetaHeaders(t *testing.T) { assert.ErrorContains(t, err, "error reading response message length") mockStream.AssertExpectations(t) - // Error case: incomplete message read + // Error case: incomplete message read then EOF mockStream = new(MockStream) mockStream.On("Read", mock.Anything).Return(lengthBytes, nil).Once() - mockStream.On("Read", mock.Anything).Return(data[:len(data)/2], nil).Once() // Return only half the data + mockStream.On("Read", mock.Anything).Return(data[:len(data)/2], 
nil).Once() + mockStream.On("Read", mock.Anything).Return([]byte{}, fmt.Errorf("EOF")).Once() _, err = readPeerMetaHeaders(mockStream, p2pPeer, n.supportedProtocolVersions) assert.ErrorContains(t, err, "error reading response message") mockStream.AssertExpectations(t) @@ -136,6 +138,26 @@ func TestReadPeerMetaHeaders(t *testing.T) { _, err = readPeerMetaHeaders(mockStream, p2pPeer, n.supportedProtocolVersions) assert.ErrorContains(t, err, "does not support any of the supported protocol versions") mockStream.AssertExpectations(t) + + // Verify short reads are handled: length arrives in two reads + mockStream = new(MockStream) + mockStream.On("Read", mock.Anything).Return(lengthBytes[:1], nil).Once() + mockStream.On("Read", mock.Anything).Return(lengthBytes[1:], nil).Once() + mockStream.On("Read", mock.Anything).Return(data, nil).Once() + metaInfo, err = readPeerMetaHeaders(mockStream, p2pPeer, n.supportedProtocolVersions) + assert.NoError(t, err) + assert.Equal(t, "1.0", metaInfo.version) + mockStream.AssertExpectations(t) + + // Verify short reads are handled: body arrives in two reads + mockStream = new(MockStream) + mockStream.On("Read", mock.Anything).Return(lengthBytes, nil).Once() + mockStream.On("Read", mock.Anything).Return(data[:len(data)/2], nil).Once() + mockStream.On("Read", mock.Anything).Return(data[len(data)/2:], nil).Once() + metaInfo, err = readPeerMetaHeaders(mockStream, p2pPeer, n.supportedProtocolVersions) + assert.NoError(t, err) + assert.Equal(t, "1.0", metaInfo.version) + mockStream.AssertExpectations(t) } func TestWritePeerMetaHeaders(t *testing.T) { diff --git a/util/execpool/stream_test.go b/util/execpool/stream_test.go index fe7bbd227b..00a7d35673 100644 --- a/util/execpool/stream_test.go +++ b/util/execpool/stream_test.go @@ -122,8 +122,8 @@ func testStreamToBatchCore(wg *sync.WaitGroup, mockJobs <-chan *mockJob, done <- sv.WaitForStop() } -// TestStreamToBatchBasic tests the basic functionality -func TestStreamToBatchBasic(t *testing.T) 
{ +// TestStream_ToBatchBasic tests the basic functionality +func TestStream_ToBatchBasic(t *testing.T) { partitiontest.PartitionTest(t) numJobs := 400 @@ -197,8 +197,8 @@ func TestStreamToBatchBasic(t *testing.T) { } } -// TestNoInputYet let the service start and get to the timeout without any inputs -func TestNoInputYet(t *testing.T) { +// TestStream_NoInputYet let the service start and get to the timeout without any inputs +func TestStream_NoInputYet(t *testing.T) { partitiontest.PartitionTest(t) numJobs := 1 @@ -228,8 +228,8 @@ func TestNoInputYet(t *testing.T) { wg.Wait() } -// TestMutipleBatchAttempts tests the behavior when multiple batch attempts will fail and the stream blocks -func TestMutipleBatchAttempts(t *testing.T) { +// TestStream_MultipleBatchAttempts tests the behavior when multiple batch attempts will fail and the stream blocks +func TestStream_MultipleBatchAttempts(t *testing.T) { partitiontest.PartitionTest(t) mp := mockPool{ @@ -297,9 +297,9 @@ func TestMutipleBatchAttempts(t *testing.T) { sv.WaitForStop() } -// TestErrors tests all the cases where exec pool returned error is handled +// TestStream_Errors tests all the cases where exec pool returned error is handled // by ending the stream processing -func TestErrors(t *testing.T) { +func TestStream_Errors(t *testing.T) { partitiontest.PartitionTest(t) mp := mockPool{ @@ -371,9 +371,9 @@ func TestErrors(t *testing.T) { sv.WaitForStop() } -// TestPendingJobOnRestart makes sure a pending job in the exec pool is cancled -// when the Stream ctx is cancled, and a now one started with a new ctx -func TestPendingJobOnRestart(t *testing.T) { +// TestStream_PendingJobOnRestart makes sure a pending job in the exec pool is canceled +// when the Stream ctx is canceled, and a new one started with a new ctx +func TestStream_PendingJobOnRestart(t *testing.T) { partitiontest.PartitionTest(t) mp := mockPool{ @@ -429,7 +429,7 @@ func TestPendingJobOnRestart(t *testing.T) { <-mp.asyncDelay <-mp.asyncDelay - // 
wait for the notifiation from cleanup before checking the TestPendingJobOnRestart + // wait for the notification from cleanup before checking the TestStream_PendingJobOnRestart <-mbp.notify require.Error(t, mj.returnError) require.False(t, mj.processed)