diff --git a/headertest/dummy_header.go b/headertest/dummy_header.go index a6576e93..39f932f0 100644 --- a/headertest/dummy_header.go +++ b/headertest/dummy_header.go @@ -115,11 +115,11 @@ func (d *DummyHeader) Verify(header header.Header) error { } if header.Height() <= d.Height() { - return fmt.Errorf("expected new header Height to be larger than old header Time") + return fmt.Errorf("expected new header Height %d to be larger than old header Height %d", header.Height(), d.Height()) } if header.Time().Before(d.Time()) { - return fmt.Errorf("expected new header Time to be after old header Time") + return fmt.Errorf("expected new header Time %v to be after old header Time %v", header.Time(), d.Time()) } return nil diff --git a/headertest/dummy_suite.go b/headertest/dummy_suite.go index 740c0d6c..f5a81bee 100644 --- a/headertest/dummy_suite.go +++ b/headertest/dummy_suite.go @@ -42,6 +42,7 @@ func (s *DummySuite) NextHeader() *DummyHeader { } dh := RandDummyHeader(s.t) + dh.Raw.Time = s.head.Time().Add(time.Nanosecond) dh.Raw.Height = s.head.Height() + 1 dh.Raw.PreviousHash = s.head.Hash() _ = dh.rehash() diff --git a/headertest/store.go b/headertest/store.go index 2785b876..0fe9d867 100644 --- a/headertest/store.go +++ b/headertest/store.go @@ -48,7 +48,7 @@ func (m *Store[H]) Height() uint64 { return uint64(m.HeadHeight) } -func (m *Store[H]) Head(context.Context) (H, error) { +func (m *Store[H]) Head(context.Context, ...header.HeadOption) (H, error) { return m.Headers[m.HeadHeight], nil } diff --git a/interface.go b/interface.go index 57fe1574..e9596196 100644 --- a/interface.go +++ b/interface.go @@ -127,5 +127,5 @@ type Getter[H Header] interface { // reporting it. type Head[H Header] interface { // Head returns the latest known header. - Head(context.Context) (H, error) + Head(context.Context, ...HeadOption) (H, error) } diff --git a/local/exchange.go b/local/exchange.go index 68866fd5..df1858e3 100644 --- a/local/exchange.go +++ b/local/exchange.go @@ -26,7 +26,7 @@ func (l *Exchange[H]) Stop(context.Context) error { return nil } -func (l *Exchange[H]) Head(ctx context.Context) (H, error) { +func (l *Exchange[H]) Head(ctx context.Context, _ ...header.HeadOption) (H, error) { return l.store.Head(ctx) } diff --git a/opts.go b/opts.go new file mode 100644 index 00000000..dcf04d8d --- /dev/null +++ b/opts.go @@ -0,0 +1,17 @@ +package header + +type HeadOption func(opts *HeadParams) + +// HeadParams contains options to be used for Head interface methods +type HeadParams struct { + // TrustedHead allows the caller of Head to specify a trusted header + // against which the underlying implementation of Head can verify against. + TrustedHead Header +} + +// WithTrustedHead sets the TrustedHead parameter to the given header. +func WithTrustedHead(verified Header) func(opts *HeadParams) { + return func(opts *HeadParams) { + opts.TrustedHead = verified + } +} diff --git a/p2p/exchange.go b/p2p/exchange.go index 99d7f87c..9d41c9d8 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -3,7 +3,6 @@ package p2p import ( "bytes" "context" - "errors" "fmt" "math/rand" "sort" @@ -21,10 +20,15 @@ import ( var log = logging.Logger("header/p2p") -// the minimum number of headers of the same height received from trusted peers -// to determine the network head. If all trusted header will return headers with -// non-equal height, then the highest header will be chosen. -const minTrustedHeadResponses = 2 +// minHeadResponses is the minimum number of headers of the same height +// received from peers to determine the network head. If all trusted peers +// will return headers with non-equal height, then the highest header will be +// chosen. +const minHeadResponses = 2 + +// maxUntrustedHeadRequests is the number of head requests to be made to +// the network in order to determine the network head. +var maxUntrustedHeadRequests = 4 // Exchange enables sending outbound HeaderRequests to the network as well as // handling inbound HeaderRequests from the network. @@ -72,26 +76,16 @@ func NewExchange[H header.Header]( return ex, nil } -func (ex *Exchange[H]) Start(context.Context) error { +func (ex *Exchange[H]) Start(ctx context.Context) error { ex.ctx, ex.cancel = context.WithCancel(context.Background()) log.Infow("client: starting client", "protocol ID", ex.protocolID) - trustedPeers := ex.trustedPeers() - - for _, p := range trustedPeers { - // Try to pre-connect to trusted peers. - // We don't really care if we succeed at this point - // and just need any peers in the peerTracker asap - go func(p peer.ID) { - err := ex.host.Connect(ex.ctx, peer.AddrInfo{ID: p}) - if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - log.Debugw("err connecting to a bootstrap peer", "err", err, "peer", p) - } - }(p) - } go ex.peerTracker.gc() go ex.peerTracker.track() - return nil + + // bootstrap the peerTracker with trusted peers as well as previously seen + // peers if provided. + return ex.peerTracker.bootstrap(ctx, ex.trustedPeers()) } func (ex *Exchange[H]) Stop(ctx context.Context) error { @@ -106,7 +100,7 @@ func (ex *Exchange[H]) Stop(ctx context.Context) error { // The Head must be verified thereafter where possible. // We request in parallel all the trusted peers, compare their response // and return the highest one. -func (ex *Exchange[H]) Head(ctx context.Context) (H, error) { +func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption) (H, error) { log.Debug("requesting head") reqCtx := ctx @@ -121,30 +115,61 @@ func (ex *Exchange[H]) Head(ctx context.Context) (H, error) { defer cancel() } + reqParams := header.HeadParams{} + for _, opt := range opts { + opt(&reqParams) + } + + peers := ex.trustedPeers() + + // the TrustedHead field indicates whether the Exchange should use + // trusted peers for its Head request. If nil, trusted peers will + // be used. If non-nil, Exchange will ask several peers from its network for + // their Head and verify against the given trusted header. + useTrackedPeers := reqParams.TrustedHead != nil + if useTrackedPeers { + trackedPeers := ex.peerTracker.getPeers(maxUntrustedHeadRequests) + if len(trackedPeers) > 0 { + peers = trackedPeers + log.Debugw("requesting head from tracked peers", "amount", len(peers)) + } + } + var ( - zero H - trustedPeers = ex.trustedPeers() - headerRespCh = make(chan H, len(trustedPeers)) - headerReq = &p2p_pb.HeaderRequest{ + zero H + headerReq = &p2p_pb.HeaderRequest{ Data: &p2p_pb.HeaderRequest_Origin{Origin: uint64(0)}, Amount: 1, } + headerRespCh = make(chan H, len(peers)) ) - for _, from := range trustedPeers { + for _, from := range peers { go func(from peer.ID) { headers, err := ex.request(reqCtx, from, headerReq) if err != nil { - log.Errorw("head request to trusted peer failed", "trustedPeer", from, "err", err) + log.Errorw("head request to peer failed", "peer", from, "err", err) headerRespCh <- zero return } + // if tracked (untrusted) peers were requested, verify head + if useTrackedPeers { + err = reqParams.TrustedHead.Verify(headers[0]) + if err != nil { + log.Errorw("verifying head received from tracked peer", "tracked peer", from, + "err", err) + // bad head was given, block peer + ex.peerTracker.blockPeer(from, fmt.Errorf("returned bad head: %w", err)) + headerRespCh <- zero + return + } + } // request ensures that the result slice will have at least one Header headerRespCh <- headers[0] }(from) } - headers := make([]H, 0, len(trustedPeers)) - for range trustedPeers { + headers := make([]H, 0, len(peers)) + for range peers { select { case h := <-headerRespCh: if !h.IsZero() { @@ -346,7 +371,7 @@ func bestHead[H header.Header](result []H) (H, error) { // try to find Header with the maximum height that was received at least from 2 peers for _, res := range result { - if counter[res.Hash().String()] >= minTrustedHeadResponses { + if counter[res.Hash().String()] >= minHeadResponses { return res, nil } } diff --git a/p2p/exchange_test.go b/p2p/exchange_test.go index b0fb299c..02f563df 100644 --- a/p2p/exchange_test.go +++ b/p2p/exchange_test.go @@ -2,6 +2,7 @@ package p2p import ( "context" + "strconv" "testing" "time" @@ -18,23 +19,73 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/celestiaorg/go-libp2p-messenger/serde" + "github.com/celestiaorg/go-header" "github.com/celestiaorg/go-header/headertest" p2p_pb "github.com/celestiaorg/go-header/p2p/pb" - "github.com/celestiaorg/go-libp2p-messenger/serde" ) const networkID = "private" func TestExchange_RequestHead(t *testing.T) { - hosts := createMocknet(t, 2) - exchg, store := createP2PExAndServer(t, hosts[0], hosts[1]) - // perform header request - header, err := exchg.Head(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + hosts := createMocknet(t, 3) + exchg, trustedStore := createP2PExAndServer(t, hosts[0], hosts[1]) + + // create new server-side exchange that will act as the tracked peer + // it will have a higher chain head than the trusted peer so that the + // test can determine which peer was asked + trackedStore := headertest.NewStore[*headertest.DummyHeader](t, headertest.NewTestSuite(t), 50) + serverSideEx, err := NewExchangeServer[*headertest.DummyHeader](hosts[2], trackedStore, + WithNetworkID[ServerParameters](networkID), + ) + require.NoError(t, err) + err = serverSideEx.Start(ctx) require.NoError(t, err) + t.Cleanup(func() { + err = serverSideEx.Stop(ctx) + require.NoError(t, err) + }) - assert.Equal(t, store.Headers[store.HeadHeight].Height(), header.Height()) - assert.Equal(t, store.Headers[store.HeadHeight].Hash(), header.Hash()) + tests := []struct { + requestFromTrusted bool + lastHeader header.Header + expectedHeight int64 + expectedHash header.Hash + }{ + // routes to trusted peer only + { + requestFromTrusted: true, + lastHeader: trustedStore.Headers[trustedStore.HeadHeight-1], + expectedHeight: trustedStore.HeadHeight, + expectedHash: trustedStore.Headers[trustedStore.HeadHeight].Hash(), + }, + // routes to tracked peers and takes highest chain head + { + requestFromTrusted: false, + lastHeader: trackedStore.Headers[trackedStore.HeadHeight-1], + expectedHeight: trackedStore.HeadHeight, + expectedHash: trackedStore.Headers[trackedStore.HeadHeight].Hash(), + }, + } + + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var opts []header.HeadOption + if !tt.requestFromTrusted { + opts = append(opts, header.WithTrustedHead(tt.lastHeader)) + } + + header, err := exchg.Head(ctx, opts...) + require.NoError(t, err) + + assert.Equal(t, tt.expectedHeight, header.Height()) + assert.Equal(t, tt.expectedHash, header.Hash()) + }) + } } func TestExchange_RequestHead_UnresponsivePeer(t *testing.T) { @@ -532,7 +583,7 @@ func (t *timedOutStore) HasAt(_ context.Context, _ uint64) bool { return true } -func (t *timedOutStore) Head(_ context.Context) (*headertest.DummyHeader, error) { +func (t *timedOutStore) Head(context.Context, ...header.HeadOption) (*headertest.DummyHeader, error) { time.Sleep(t.timeout) return nil, header.ErrNoHead } diff --git a/p2p/peer_tracker.go b/p2p/peer_tracker.go index b808c363..8981fd7f 100644 --- a/p2p/peer_tracker.go +++ b/p2p/peer_tracker.go @@ -8,7 +8,7 @@ import ( "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" + libpeer "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/p2p/net/conngater" ) @@ -35,10 +35,10 @@ type peerTracker struct { // trackedPeers contains active peers that we can request to. // we cache the peer once they disconnect, // so we can guarantee that peerQueue will only contain active peers - trackedPeers map[peer.ID]*peerStat + trackedPeers map[libpeer.ID]*peerStat // disconnectedPeers contains disconnected peers. In case if peer does not return // online until pruneDeadline, it will be removed and its score will be lost - disconnectedPeers map[peer.ID]*peerStat + disconnectedPeers map[libpeer.ID]*peerStat // an optional interface used to periodically dump // good peers during garbage collection @@ -60,8 +60,8 @@ func newPeerTracker( return &peerTracker{ host: h, connGater: connGater, - trackedPeers: make(map[peer.ID]*peerStat), - disconnectedPeers: make(map[peer.ID]*peerStat), + trackedPeers: make(map[libpeer.ID]*peerStat), + disconnectedPeers: make(map[libpeer.ID]*peerStat), pidstore: pidstore, ctx: ctx, cancel: cancel, @@ -69,6 +69,51 @@ func newPeerTracker( } } +// bootstrap will bootstrap the peerTracker with the given trusted peers and if +// a pidstore was given, will also attempt to bootstrap the tracker with previously +// seen peers. +// +// NOTE: bootstrap is intended to be used with an on-disk peerstore.Peerstore as +// the peerTracker needs access to the previously-seen peers' AddrInfo on start. +func (p *peerTracker) bootstrap(ctx context.Context, trusted []libpeer.ID) error { + // bootstrap connections to trusted + wg := sync.WaitGroup{} + wg.Add(len(trusted)) + defer wg.Wait() + for _, trust := range trusted { + trust := trust + go func() { + defer wg.Done() + p.connectToPeer(ctx, trust) + }() + } + + // short-circuit if pidstore was not provided + if p.pidstore == nil { + return nil + } + + prevSeen, err := p.pidstore.Load(ctx) + if err != nil { + return err + } + + for _, peer := range prevSeen { + go p.connectToPeer(ctx, peer) + } + return nil +} + +// connectToPeer attempts to connect to the given peer. +func (p *peerTracker) connectToPeer(ctx context.Context, peer libpeer.ID) { + err := p.host.Connect(ctx, p.host.Peerstore().PeerInfo(peer)) + if err != nil { + log.Debugw("failed to connect to peer", "id", peer.String(), "err", err) + return + } + log.Debugw("connected to peer", "id", peer.String()) +} + func (p *peerTracker) track() { defer func() { p.done <- struct{}{} @@ -105,7 +150,22 @@ func (p *peerTracker) track() { } } -func (p *peerTracker) connected(pID peer.ID) { +// getPeers returns the tracker's currently tracked peers up to the `max`. +func (p *peerTracker) getPeers(max int) []libpeer.ID { + p.peerLk.RLock() + defer p.peerLk.RUnlock() + + peers := make([]libpeer.ID, 0, max) + for peer := range p.trackedPeers { + peers = append(peers, peer) + if len(peers) == max { + break + } + } + return peers +} + +func (p *peerTracker) connected(pID libpeer.ID) { if p.host.ID() == pID { return } @@ -138,7 +198,7 @@ func (p *peerTracker) connected(pID peer.ID) { p.trackedPeers[pID] = stats } -func (p *peerTracker) disconnected(pID peer.ID) { +func (p *peerTracker) disconnected(pID libpeer.ID) { p.peerLk.Lock() defer p.peerLk.Unlock() stats, ok := p.trackedPeers[pID] @@ -201,7 +261,7 @@ func (p *peerTracker) dumpPeers(ctx context.Context) { return } - peers := make([]peer.ID, 0, len(p.trackedPeers)) + peers := make([]libpeer.ID, 0, len(p.trackedPeers)) p.peerLk.RLock() for id := range p.trackedPeers { @@ -239,7 +299,7 @@ func (p *peerTracker) stop(ctx context.Context) error { } // blockPeer blocks a peer on the networking level and removes it from the local cache. -func (p *peerTracker) blockPeer(pID peer.ID, reason error) { +func (p *peerTracker) blockPeer(pID libpeer.ID, reason error) { // add peer to the blacklist, so we can't connect to it in the future. err := p.connGater.BlockPeer(pID) if err != nil { diff --git a/p2p/peer_tracker_test.go b/p2p/peer_tracker_test.go index 66f9c613..ff5eaf31 100644 --- a/p2p/peer_tracker_test.go +++ b/p2p/peer_tracker_test.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" testpeer "github.com/libp2p/go-libp2p/core/test" "github.com/libp2p/go-libp2p/p2p/net/conngater" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -73,6 +74,44 @@ func TestPeerTracker_BlockPeer(t *testing.T) { require.True(t, connGater.ListBlockedPeers()[0] == h[1].ID()) } +func TestPeerTracker_Bootstrap(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + connGater, err := conngater.NewBasicConnectionGater(sync.MutexWrap(datastore.NewMapDatastore())) + require.NoError(t, err) + + // mn := createMocknet(t, 10) + mn, err := mocknet.FullMeshConnected(10) + require.NoError(t, err) + + // store peers to peerstore + prevSeen := make([]peer.ID, 9) + for i, peer := range mn.Hosts()[1:] { + prevSeen[i] = peer.ID() + + // disconnect so they're not already connected on attempt to + // connect + err = mn.DisconnectPeers(mn.Hosts()[i].ID(), peer.ID()) + require.NoError(t, err) + } + pidstore := newDummyPIDStore() + // only store 7 peers to pidstore, and use 2 as trusted + err = pidstore.Put(ctx, prevSeen[2:]) + require.NoError(t, err) + + tracker := newPeerTracker(mn.Hosts()[0], connGater, pidstore) + + go tracker.track() + + err = tracker.bootstrap(ctx, prevSeen[:2]) + require.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(tracker.getPeers(7)) > 0 + }, time.Millisecond*500, time.Millisecond*100) +} + type dummyPIDStore struct { ds datastore.Datastore key datastore.Key diff --git a/store/store.go b/store/store.go index fbbe8dca..662dc43a 100644 --- a/store/store.go +++ b/store/store.go @@ -158,7 +158,7 @@ func (s *Store[H]) Height() uint64 { return s.heightSub.Height() } -func (s *Store[H]) Head(ctx context.Context) (H, error) { +func (s *Store[H]) Head(ctx context.Context, _ ...header.HeadOption) (H, error) { head, err := s.GetByHeight(ctx, s.heightSub.Height()) if err == nil { return head, nil diff --git a/sync/sync_getter.go b/sync/sync_getter.go index b25b2ac1..2a14e0ed 100644 --- a/sync/sync_getter.go +++ b/sync/sync_getter.go @@ -39,9 +39,9 @@ func (sg *syncGetter[H]) Unlock() { } // Head must be called with held Lock. -func (sg *syncGetter[H]) Head(ctx context.Context) (H, error) { +func (sg *syncGetter[H]) Head(ctx context.Context, opts ...header.HeadOption) (H, error) { sg.checkLock("Head without preceding Lock on syncGetter") - return sg.Getter.Head(ctx) + return sg.Getter.Head(ctx, opts...) } // checkLock ensures api safety diff --git a/sync/sync_getter_test.go b/sync/sync_getter_test.go index 06fcd7ad..e92dd555 100644 --- a/sync/sync_getter_test.go +++ b/sync/sync_getter_test.go @@ -47,7 +47,7 @@ type fakeGetter[H header.Header] struct { hits atomic.Uint32 } -func (f *fakeGetter[H]) Head(ctx context.Context) (h H, err error) { +func (f *fakeGetter[H]) Head(ctx context.Context, _ ...header.HeadOption) (h H, err error) { f.hits.Add(1) select { case <-time.After(time.Millisecond * 100): diff --git a/sync/sync_head.go b/sync/sync_head.go index d4b35dbc..1e0e88b4 100644 --- a/sync/sync_head.go +++ b/sync/sync_head.go @@ -15,7 +15,7 @@ import ( // Known subjective head is considered network head if it is recent enough(now-timestamp<=blocktime) // Otherwise, head is requested from a trusted peer and // set as the new subjective head, assuming that trusted peer is always fully synced. -func (s *Syncer[H]) Head(ctx context.Context) (H, error) { +func (s *Syncer[H]) Head(ctx context.Context, _ ...header.HeadOption) (H, error) { sbjHead, err := s.subjectiveHead(ctx) if err != nil { return sbjHead, err @@ -24,7 +24,7 @@ func (s *Syncer[H]) Head(ctx context.Context) (H, error) { if isRecent(sbjHead, s.Params.blockTime) { return sbjHead, nil } - // otherwise, request head from a trusted peer, as we assume it is fully synced + // otherwise, request head from the network // // TODO(@Wondertan): Here is another potential networking optimization: // * From sbjHead's timestamp and current time predict the time to the next header(TNH) @@ -40,7 +40,7 @@ func (s *Syncer[H]) Head(ctx context.Context) (H, error) { return s.Head(ctx) } defer s.getter.Unlock() - netHead, err := s.getter.Head(ctx) + netHead, err := s.getter.Head(ctx, header.WithTrustedHead(sbjHead)) if err != nil { log.Warnw("failed to return head from trusted peer, returning subjective head which may not be recent", "sbjHead", sbjHead.Height(), "err", err) return sbjHead, nil diff --git a/sync/sync_head_test.go b/sync/sync_head_test.go index 6d464ca4..715d8dcd 100644 --- a/sync/sync_head_test.go +++ b/sync/sync_head_test.go @@ -7,11 +7,16 @@ import ( "testing" "time" + "github.com/ipfs/go-datastore" + sync2 "github.com/ipfs/go-datastore/sync" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/celestiaorg/go-header" "github.com/celestiaorg/go-header/headertest" + "github.com/celestiaorg/go-header/local" + "github.com/celestiaorg/go-header/store" ) func TestSyncer_incomingNetworkHeadRaces(t *testing.T) { @@ -19,6 +24,7 @@ func TestSyncer_incomingNetworkHeadRaces(t *testing.T) { t.Cleanup(cancel) suite := headertest.NewTestSuite(t) + store := headertest.NewStore[*headertest.DummyHeader](t, suite, 1) syncer, err := NewSyncer[*headertest.DummyHeader]( store, @@ -43,4 +49,99 @@ func TestSyncer_incomingNetworkHeadRaces(t *testing.T) { wg.Wait() assert.EqualValues(t, 1, hits.Load()) + +} + +// TestSyncer_HeadWithTrustedHead tests whether the syncer +// requests Head (new sync target) from tracked peers when +// it already has a subjective head within the unbonding period. +func TestSyncer_HeadWithTrustedHead(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + suite := headertest.NewTestSuite(t) + head := suite.Head() + + localStore := store.NewTestStore(ctx, t, head) + + remoteStore, err := store.NewStoreWithHead(ctx, sync2.MutexWrap(datastore.NewMapDatastore()), head) + require.NoError(t, err) + err = remoteStore.Append(ctx, suite.GenDummyHeaders(100)...) + require.NoError(t, err) + + // create a wrappedGetter to track exchange interactions + wrappedGetter := newWrappedGetter(local.NewExchange[*headertest.DummyHeader](remoteStore)) + + syncer, err := NewSyncer[*headertest.DummyHeader]( + wrappedGetter, + localStore, + headertest.NewDummySubscriber(), + // forces a request for a new sync target + WithBlockTime(time.Nanosecond), + // ensures that syncer's store contains a subjective head that is within + // the unbonding period so that the syncer can use a header from the network + // as a sync target + WithTrustingPeriod(time.Hour), + ) + require.NoError(t, err) + + // start the syncer which triggers a Head request that will + // load the syncer's subjective head from the store, and request + // a new sync target from the network rather than from trusted peers + err = syncer.Start(ctx) + require.NoError(t, err) + t.Cleanup(func() { + err = syncer.Stop(ctx) + require.NoError(t, err) + }) + + // ensure the syncer really requested Head from the network + // rather than from trusted peers + require.True(t, wrappedGetter.withTrustedHead) +} + +type wrappedGetter struct { + ex header.Exchange[*headertest.DummyHeader] + + // withTrustedHead indicates whether TrustedHead was set by the request + // via the WithTrustedHead opt. + withTrustedHead bool +} + +func newWrappedGetter(ex header.Exchange[*headertest.DummyHeader]) *wrappedGetter { + return &wrappedGetter{ + ex: ex, + withTrustedHead: false, + } +} + +func (t *wrappedGetter) Head(ctx context.Context, options ...header.HeadOption) (*headertest.DummyHeader, error) { + params := header.HeadParams{} + for _, opt := range options { + opt(¶ms) + } + if params.TrustedHead != nil { + t.withTrustedHead = true + } + return t.ex.Head(ctx, options...) +} + +func (t *wrappedGetter) Get(ctx context.Context, hash header.Hash) (*headertest.DummyHeader, error) { + //TODO implement me + panic("implement me") +} + +func (t *wrappedGetter) GetByHeight(ctx context.Context, u uint64) (*headertest.DummyHeader, error) { + //TODO implement me + panic("implement me") +} + +func (t *wrappedGetter) GetRangeByHeight(ctx context.Context, from, amount uint64) ([]*headertest.DummyHeader, error) { + //TODO implement me + panic("implement me") +} + +func (t *wrappedGetter) GetVerifiedRange(ctx context.Context, from *headertest.DummyHeader, amount uint64) ([]*headertest.DummyHeader, error) { + //TODO implement me + panic("implement me") } diff --git a/sync/sync_test.go b/sync/sync_test.go index 06f486c3..cebf32d7 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -306,7 +306,6 @@ func TestSync_InvalidSyncTarget(t *testing.T) { local.NewExchange[*headertest.DummyHeader](remoteStore), localStore, headertest.NewDummySubscriber(), - WithTrustingPeriod(time.Second), WithBlockTime(time.Nanosecond), // force syncer to request more recent sync target ) require.NoError(t, err)