From b8065f953dcd5660848e96ab03e3e9eab6108194 Mon Sep 17 00:00:00 2001 From: Wondertan Date: Mon, 31 Jul 2023 12:21:13 +0200 Subject: [PATCH] refactor: harden verification --- errors.go | 20 ------ headertest/dummy_header.go | 15 +--- headertest/dummy_suite.go | 2 + p2p/exchange_test.go | 7 +- p2p/subscriber.go | 5 +- store/store.go | 3 +- sync/sync_head.go | 68 +++++++++--------- sync/sync_test.go | 7 +- sync/verify/verify.go | 112 ++++++++++++++++++++++++++++++ sync/verify/verify_test.go | 139 +++++++++++++++++++++++++++++++++++++ 10 files changed, 300 insertions(+), 78 deletions(-) delete mode 100644 errors.go create mode 100644 sync/verify/verify.go create mode 100644 sync/verify/verify_test.go diff --git a/errors.go b/errors.go deleted file mode 100644 index 127499c7..00000000 --- a/errors.go +++ /dev/null @@ -1,20 +0,0 @@ -package header - -import "fmt" - -// VerifyError is thrown during for Headers failed verification. -type VerifyError struct { - // Reason why verification failed as inner error. - Reason error - // Uncertain signals that the was not enough information to conclude a Header is correct or not. - // May happen with recent Headers during unfinished historical sync or because of local errors. - Uncertain bool -} - -func (vr *VerifyError) Error() string { - return fmt.Sprintf("header: verify: %s", vr.Reason.Error()) -} - -func (vr *VerifyError) Unwrap() error { - return vr.Reason -} diff --git a/headertest/dummy_header.go b/headertest/dummy_header.go index 39f932f0..20fda64d 100644 --- a/headertest/dummy_header.go +++ b/headertest/dummy_header.go @@ -61,7 +61,7 @@ func (d *DummyHeader) IsZero() bool { } func (d *DummyHeader) ChainID() string { - return "private" + return d.Raw.ChainID } func (d *DummyHeader) Hash() header.Hash { @@ -109,19 +109,6 @@ func (d *DummyHeader) Verify(header header.Header) error { return fmt.Errorf("header at height %d failed verification", header.Height()) } - epsilon := 10 * time.Second - if header.Time().After(time.Now().Add(epsilon)) { - return fmt.Errorf("header Time too far in the future") - } - - if header.Height() <= d.Height() { - return fmt.Errorf("expected new header Height %d to be larger than old header Height %d", header.Height(), d.Height()) - } - - if header.Time().Before(d.Time()) { - return fmt.Errorf("expected new header Time %v to be after old header Time %v", header.Time(), d.Time()) - } - return nil } diff --git a/headertest/dummy_suite.go b/headertest/dummy_suite.go index f5a81bee..1b0ad9f7 100644 --- a/headertest/dummy_suite.go +++ b/headertest/dummy_suite.go @@ -45,6 +45,7 @@ func (s *DummySuite) NextHeader() *DummyHeader { dh.Raw.Time = s.head.Time().Add(time.Nanosecond) dh.Raw.Height = s.head.Height() + 1 dh.Raw.PreviousHash = s.head.Hash() + dh.Raw.ChainID = s.head.ChainID() _ = dh.rehash() s.head = dh return s.head @@ -57,6 +58,7 @@ func (s *DummySuite) genesis() *DummyHeader { PreviousHash: nil, Height: 1, Time: time.Now().Add(-10 * time.Second).UTC(), + ChainID: "test", }, } } diff --git a/p2p/exchange_test.go b/p2p/exchange_test.go index 02f563df..71cee81e 100644 --- a/p2p/exchange_test.go +++ b/p2p/exchange_test.go @@ -19,14 +19,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/celestiaorg/go-libp2p-messenger/serde" - "github.com/celestiaorg/go-header" "github.com/celestiaorg/go-header/headertest" p2p_pb "github.com/celestiaorg/go-header/p2p/pb" + "github.com/celestiaorg/go-libp2p-messenger/serde" ) -const networkID = "private" +const networkID = "test" // must match the chain-id in test suite func TestExchange_RequestHead(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -393,7 +392,7 @@ func TestExchange_RequestByHashFails(t *testing.T) { func TestExchange_HandleHeaderWithDifferentChainID(t *testing.T) { hosts := createMocknet(t, 2) exchg, store := createP2PExAndServer(t, hosts[0], hosts[1]) - exchg.Params.chainID = "test" + exchg.Params.chainID = "test1" _, err := exchg.Head(context.Background()) require.Error(t, err) diff --git a/p2p/subscriber.go b/p2p/subscriber.go index d9ad88f6..8f89caa6 100644 --- a/p2p/subscriber.go +++ b/p2p/subscriber.go @@ -9,6 +9,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/celestiaorg/go-header" + "github.com/celestiaorg/go-header/sync/verify" ) // Subscriber manages the lifecycle and relationship of header Module @@ -77,12 +78,12 @@ func (p *Subscriber[H]) SetVerifier(val func(context.Context, H) error) error { // additional unmarhalling msg.ValidatorData = hdr - var verErr *header.VerifyError + var verErr *verify.VerifyError err = val(ctx, hdr) switch { case err == nil: return pubsub.ValidationAccept - case errors.As(err, &verErr) && verErr.Uncertain: + case errors.As(err, &verErr) && verErr.SoftFailure: return pubsub.ValidationIgnore default: return pubsub.ValidationReject diff --git a/store/store.go b/store/store.go index 662dc43a..c7c94f9e 100644 --- a/store/store.go +++ b/store/store.go @@ -12,6 +12,7 @@ import ( logging "github.com/ipfs/go-log/v2" "github.com/celestiaorg/go-header" + "github.com/celestiaorg/go-header/sync/verify" ) var log = logging.Logger("header/store") @@ -333,7 +334,7 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { err = head.Verify(h) if err != nil { - var verErr *header.VerifyError + var verErr *verify.VerifyError if errors.As(err, &verErr) { log.Errorw("invalid header", "height_of_head", head.Height(), diff --git a/sync/sync_head.go b/sync/sync_head.go index 1dc43639..707cdfa0 100644 --- a/sync/sync_head.go +++ b/sync/sync_head.go @@ -6,6 +6,7 @@ import ( "time" "github.com/celestiaorg/go-header" + "github.com/celestiaorg/go-header/sync/verify" ) // Head returns the Network Head. @@ -40,8 +41,8 @@ func (s *Syncer[H]) Head(ctx context.Context, _ ...header.HeadOption) (H, error) defer s.getter.Unlock() netHead, err := s.getter.Head(ctx, header.WithTrustedHead(sbjHead)) if err != nil { - log.Warnw("failed to return head from trusted peer, returning subjective head which may not be recent", "sbjHead", sbjHead.Height(), "err", err) - return sbjHead, nil + log.Warnw("failed to get recent head, returning current subjective", "sbjHead", sbjHead.Height(), "err", err) + return s.subjectiveHead(ctx) } // process and validate netHead fetched from trusted peers // NOTE: We could trust the netHead like we do during 'automatic subjective initialization' @@ -134,54 +135,53 @@ func (s *Syncer[H]) setSubjectiveHead(ctx context.Context, netHead H) { // incomingNetworkHead processes new potential network headers. // If the header valid, sets as new subjective header. -func (s *Syncer[H]) incomingNetworkHead(ctx context.Context, netHead H) error { +func (s *Syncer[H]) incomingNetworkHead(ctx context.Context, head H) error { // ensure there is no racing between network head candidates s.incomingMu.Lock() defer s.incomingMu.Unlock() - // first of all, check the validity of the netHead - err := s.validateHead(ctx, netHead) - if err != nil { + + softFailure, err := s.verify(ctx, head) + if err != nil && !softFailure { return err } - // and set it if valid - s.setSubjectiveHead(ctx, netHead) - return nil + + // TODO(@Wondertan): + // Implement setSyncTarget and use it for soft failures + s.setSubjectiveHead(ctx, head) + return err } -// validateHead checks validity of the given header against the subjective head. -func (s *Syncer[H]) validateHead(ctx context.Context, new H) error { +// verify verifies given network head candidate. +func (s *Syncer[H]) verify(ctx context.Context, newHead H) (bool, error) { sbjHead, err := s.subjectiveHead(ctx) if err != nil { log.Errorw("getting subjective head during validation", "err", err) - // local error, so uncertain - return &header.VerifyError{Reason: err, Uncertain: true} - } - if new.Height() <= sbjHead.Height() { - log.Warnw("received known network header", - "current_height", sbjHead.Height(), - "header_height", new.Height(), - "header_hash", new.Hash()) - // set uncertain, if it's from the past - return &header.VerifyError{Reason: err, Uncertain: true} - } - // perform verification - err = sbjHead.Verify(new) - var verErr *header.VerifyError - if errors.As(err, &verErr) { + // local error, so soft + return true, &verify.VerifyError{Reason: err, SoftFailure: true} + } + + var heightThreshold int64 + if s.Params.TrustingPeriod != 0 && s.Params.blockTime != 0 { + heightThreshold = int64(s.Params.TrustingPeriod / s.Params.blockTime) + } + + err = verify.Verify(sbjHead, newHead, heightThreshold) + if err == nil { + return false, nil + } + + var verErr *verify.VerifyError + if errors.As(err, &verErr) && !verErr.SoftFailure { log.Errorw("invalid network header", - "height_of_invalid", new.Height(), - "hash_of_invalid", new.Hash(), + "height_of_invalid", newHead.Height(), + "hash_of_invalid", newHead.Hash(), "height_of_subjective", sbjHead.Height(), "hash_of_subjective", sbjHead.Hash(), "reason", verErr.Reason) - return verErr } - // and accept if the header is good - return nil -} -// TODO(@Wondertan): We should request TrustingPeriod from the network's state params or -// listen for network params changes to always have a topical value. + return verErr.SoftFailure, err +} // isExpired checks if header is expired against trusting period. func isExpired(header header.Header, period time.Duration) bool { diff --git a/sync/sync_test.go b/sync/sync_test.go index ba8adb9e..470b8367 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -12,6 +12,7 @@ import ( "github.com/celestiaorg/go-header/headertest" "github.com/celestiaorg/go-header/local" "github.com/celestiaorg/go-header/store" + "github.com/celestiaorg/go-header/sync/verify" ) func TestSyncSimpleRequestingHead(t *testing.T) { @@ -277,10 +278,9 @@ func TestSyncerIncomingDuplicate(t *testing.T) { time.Sleep(time.Millisecond * 10) - var verErr *header.VerifyError + var verErr *verify.VerifyError err = syncer.incomingNetworkHead(ctx, range1[len(range1)-1]) assert.ErrorAs(t, err, &verErr) - assert.True(t, verErr.Uncertain) err = syncer.SyncWait(ctx) require.NoError(t, err) @@ -356,7 +356,8 @@ func TestSync_InvalidSyncTarget(t *testing.T) { // a new sync job to a good sync target expectedHead, err := remoteStore.Head(ctx) require.NoError(t, err) - syncer.incomingNetworkHead(ctx, expectedHead) + err = syncer.incomingNetworkHead(ctx, expectedHead) + require.NoError(t, err) // wait for syncer to finish (give it a bit of time to register // new job with new sync target) diff --git a/sync/verify/verify.go b/sync/verify/verify.go new file mode 100644 index 00000000..32255def --- /dev/null +++ b/sync/verify/verify.go @@ -0,0 +1,112 @@ +// TODO(@Wondertan): Should be just part of sync pkg and not subpkg +// +// Fix after adjacency requirement is removed from the Store. +package verify + +import ( + "errors" + "fmt" + "time" + + "github.com/celestiaorg/go-header" +) + +// DefaultHeightThreshold defines default height threshold beyond which headers are rejected +// NOTE: Compared against subjective head which is guaranteed to be non-expired +const DefaultHeightThreshold int64 = 40000 // ~ 7 days of 15 second headers + +// VerifyError is thrown during for Headers failed verification. +type VerifyError struct { + // Reason why verification failed as inner error. + Reason error + // SoftFailure means verification did not have enough information to definitively conclude a + // Header was correct or not. + // May happen with recent Headers during unfinished historical sync or because of local errors. + // TODO(@Wondertan): Better be part of signature Header.Verify() (bool, error), but kept here + // not to break + SoftFailure bool +} + +func (vr *VerifyError) Error() string { + return fmt.Sprintf("header: verify: %s", vr.Reason.Error()) +} + +func (vr *VerifyError) Unwrap() error { + return vr.Reason +} + +// Verify verifies untrusted Header against trusted following general Header checks and +// custom user-specific checks defined in Header.Verify +// +// If heightThreshold is zero, uses DefaultHeightThreshold. +// Always returns VerifyError. +func Verify[H header.Header](trstd, untrstd H, heightThreshold int64) error { + // general mandatory verification + err := verify[H](trstd, untrstd, heightThreshold) + if err != nil { + return &VerifyError{Reason: err} + } + // user defined verification + err = trstd.Verify(untrstd) + if err == nil { + return nil + } + // if that's an error, ensure we always return VerifyError + var verErr *VerifyError + if !errors.As(err, &verErr) { + verErr = &VerifyError{Reason: err} + } + // check adjacency of failed verification + adjacent := untrstd.Height() == trstd.Height()+1 + if !adjacent { + // if non-adjacent, we don't know if the header is *really* wrong + // so set as soft + verErr.SoftFailure = true + } + // we trust adjacent verification to it's fullest + // if verification fails - the header is *really* wrong + return verErr +} + +// verify is a little bro of Verify yet performs mandatory Header checks +// for any Header implementation. +func verify[H header.Header](trstd, untrstd H, heightThreshold int64) error { + if heightThreshold == 0 { + heightThreshold = DefaultHeightThreshold + } + + if untrstd.IsZero() { + return fmt.Errorf("zero header") + } + + if untrstd.ChainID() != trstd.ChainID() { + return fmt.Errorf("wrong header chain id %s, not %s", untrstd.ChainID(), trstd.ChainID()) + } + + if !untrstd.Time().After(trstd.Time()) { + return fmt.Errorf("unordered header timestamp %v is before %v", untrstd.Time(), trstd.Time()) + } + + now := time.Now() + if !untrstd.Time().Before(now.Add(clockDrift)) { + return fmt.Errorf("header timestamp %v is from future (now: %v, clock_drift: %v)", untrstd.Time(), now, clockDrift) + } + + known := untrstd.Height() <= trstd.Height() + if known { + return fmt.Errorf("known header height %d, current %d", untrstd.Height(), trstd.Height()) + } + // reject headers with height too far from the future + // this is essential for headers failed non-adjacent verification + // yet taken as sync target + adequateHeight := untrstd.Height()-trstd.Height() < heightThreshold + if !adequateHeight { + return fmt.Errorf("header height %d is far from future (current: %d, threshold: %d)", untrstd.Height(), trstd.Height(), heightThreshold) + } + + return nil +} + +// clockDrift defines how much new header's time can drift into +// the future relative to the now time during verification. +var clockDrift = 10 * time.Second diff --git a/sync/verify/verify_test.go b/sync/verify/verify_test.go new file mode 100644 index 00000000..954fd73c --- /dev/null +++ b/sync/verify/verify_test.go @@ -0,0 +1,139 @@ +package verify + +import ( + "errors" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/celestiaorg/go-header/headertest" +) + +func TestVerify(t *testing.T) { + suite := headertest.NewTestSuite(t) + trusted := suite.GenDummyHeaders(1)[0] + + tests := []struct { + prepare func() *headertest.DummyHeader + err bool + soft bool + }{ + { + prepare: func() *headertest.DummyHeader { + return nil + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.VerifyFailure = true + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.VerifyFailure = true + return untrusted + }, + err: true, + soft: true, // soft because non-adjacent + }, + { + prepare: func() *headertest.DummyHeader { + return suite.NextHeader() + }, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + err := Verify(trusted, test.prepare(), 0) + if test.err { + var verErr *VerifyError + assert.ErrorAs(t, err, &verErr) + assert.NotNil(t, errors.Unwrap(verErr)) + assert.Equal(t, test.soft, verErr.SoftFailure) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_verify(t *testing.T) { + suite := headertest.NewTestSuite(t) + trusted := suite.GenDummyHeaders(1)[0] + + tests := []struct { + prepare func() *headertest.DummyHeader + err bool + }{ + { + prepare: func() *headertest.DummyHeader { + return suite.NextHeader() + }, + }, + { + prepare: func() *headertest.DummyHeader { + return nil + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.ChainID = "gtmb" + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Time = untrusted.Raw.Time.Truncate(time.Minute * 10) + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Time = untrusted.Raw.Time.Add(time.Minute) + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Height = trusted.Height() + return untrusted + }, + err: true, + }, + { + prepare: func() *headertest.DummyHeader { + untrusted := suite.NextHeader() + untrusted.Raw.Height += 100000 + return untrusted + }, + err: true, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + err := verify(trusted, test.prepare(), 0) + if test.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +}