diff --git a/.mockery.yaml b/.mockery.yaml index d44c09dfc8..ceed741a1c 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -53,6 +53,7 @@ packages: github.com/dashpay/tenderdash/internal/statesync: interfaces: StateProvider: + ConsensusStateProvider: github.com/dashpay/tenderdash/libs/store: interfaces: Store: diff --git a/config/config.go b/config/config.go index ac43d04413..a6e4903fa9 100644 --- a/config/config.go +++ b/config/config.go @@ -1,7 +1,6 @@ package config import ( - "encoding/hex" "encoding/json" "errors" "fmt" @@ -1001,18 +1000,16 @@ type StateSyncConfig struct { // with net.Dial, for example: "host.example.com:2125". RPCServers []string `mapstructure:"rpc-servers"` - // The hash and height of a trusted block. Must be within the trust-period. - TrustHeight int64 `mapstructure:"trust-height"` - TrustHash string `mapstructure:"trust-hash"` - - // The trust period should be set so that Tendermint can detect and gossip - // misbehavior before it is considered expired. For chains based on the Cosmos SDK, - // one day less than the unbonding period should suffice. - TrustPeriod time.Duration `mapstructure:"trust-period"` - // Time to spend discovering snapshots before initiating a restore. DiscoveryTime time.Duration `mapstructure:"discovery-time"` + // Number of times to retry state sync. When retries are exhausted, the node will + // fall back to the regular block sync. Set to 0 to disable retries. Default is 3. + // + // Note that in the pessimistic case, it will take at least `discovery-time * retries` before + // falling back to block sync. + Retries int `mapstructure:"retries"` + // Temporary directory for state sync snapshot chunks, defaults to os.TempDir(). // The synchronizer will create a new, randomly named directory within this directory // and remove it when the sync is complete. @@ -1026,22 +1023,13 @@ type StateSyncConfig struct { Fetchers int `mapstructure:"fetchers"` } -func (cfg *StateSyncConfig) TrustHashBytes() []byte { - // validated in ValidateBasic, so we can safely panic here - bytes, err := hex.DecodeString(cfg.TrustHash) - if err != nil { - panic(err) - } - return bytes -} - // DefaultStateSyncConfig returns a default configuration for the state sync service func DefaultStateSyncConfig() *StateSyncConfig { return &StateSyncConfig{ - TrustPeriod: 168 * time.Hour, DiscoveryTime: 15 * time.Second, ChunkRequestTimeout: 15 * time.Second, Fetchers: 4, + Retries: 3, } } @@ -1074,21 +1062,8 @@ func (cfg *StateSyncConfig) ValidateBasic() error { return errors.New("discovery time must be 0s or greater than five seconds") } - if cfg.TrustPeriod <= 0 { - return errors.New("trusted-period is required") - } - - if cfg.TrustHeight <= 0 { - return errors.New("trusted-height is required") - } - - if len(cfg.TrustHash) == 0 { - return errors.New("trusted-hash is required") - } - - _, err := hex.DecodeString(cfg.TrustHash) - if err != nil { - return fmt.Errorf("invalid trusted-hash: %w", err) + if cfg.Retries < 0 { + return errors.New("retries must be greater than or equal to zero") } if cfg.ChunkRequestTimeout < 5*time.Second { diff --git a/config/toml.go b/config/toml.go index b6be127b28..dddf0ac43f 100644 --- a/config/toml.go +++ b/config/toml.go @@ -495,18 +495,15 @@ use-p2p = {{ .StateSync.UseP2P }} # for example: "host.example.com:2125" rpc-servers = "{{ StringsJoin .StateSync.RPCServers "," }}" -# The hash and height of a trusted block. Must be within the trust-period.
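To make the new retry semantics concrete, here is a minimal sketch (values mirror DefaultStateSyncConfig above; the computation itself is illustrative, not part of this change) of the pessimistic lower bound before the node falls back to block sync:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        // Defaults from DefaultStateSyncConfig: DiscoveryTime 15s, Retries 3.
        discoveryTime := 15 * time.Second
        retries := 3
        // Per the comment above, falling back to block sync takes at least
        // discovery-time * retries in the pessimistic case.
        worstCase := discoveryTime * time.Duration(retries)
        fmt.Println(worstCase) // 45s
    }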
-trust-height = {{ .StateSync.TrustHeight }} -trust-hash = "{{ .StateSync.TrustHash }}" - -# The trust period should be set so that Tendermint can detect and gossip misbehavior before -# it is considered expired. For chains based on the Cosmos SDK, one day less than the unbonding -# period should suffice. -trust-period = "{{ .StateSync.TrustPeriod }}" - # Time to spend discovering snapshots before initiating a restore. discovery-time = "{{ .StateSync.DiscoveryTime }}" +# Number of times to retry state sync. When retries are exhausted, the node will +# fall back to the regular block sync. Set to 0 to disable retries. Default is 3. +# Note that in the pessimistic case, it will take at least (discovery-time * retries) before +# falling back to block sync. +retries = {{ .StateSync.Retries }} + # Temporary directory for state sync snapshot chunks, defaults to os.TempDir(). # The synchronizer will create a new, randomly named directory within this directory # and remove it when the sync is complete. diff --git a/go.mod b/go.mod index 4733a3bc05..3bde02eab9 100644 --- a/go.mod +++ b/go.mod @@ -9,8 +9,8 @@ require ( github.com/btcsuite/btcutil v1.0.3-0.20201208143702-a53e38424cce github.com/containerd/continuity v0.4.5 // indirect github.com/dashpay/bls-signatures/go-bindings v0.0.0-20230207105415-06df92693ac8 - github.com/dashpay/dashd-go v0.25.0 - github.com/dashpay/dashd-go/btcec/v2 v2.1.0 // indirect + github.com/dashpay/dashd-go v0.26.1 + github.com/dashpay/dashd-go/btcec/v2 v2.2.0 // indirect github.com/fortytw2/leaktest v1.3.0 github.com/fxamacker/cbor/v2 v2.8.0 github.com/go-kit/kit v0.13.0 @@ -80,7 +80,7 @@ require ( github.com/alingse/nilnesserr v0.1.2 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/bombsimon/wsl/v4 v4.5.0 // indirect - github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f // indirect + github.com/btcsuite/btclog v1.0.0 // indirect github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd // indirect github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 // indirect github.com/bufbuild/protocompile v0.14.1 // indirect @@ -101,7 +101,7 @@ require ( github.com/containerd/log v0.1.0 // indirect github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect github.com/curioswitch/go-reassign v0.3.0 // indirect - github.com/dashpay/dashd-go/btcutil v1.2.0 // indirect + github.com/dashpay/dashd-go/btcutil v1.3.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect github.com/dgraph-io/badger/v4 v4.5.1 // indirect github.com/dgraph-io/ristretto/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 83561afc61..3aa2b1eaee 100644 --- a/go.sum +++ b/go.sum @@ -106,8 +106,9 @@ github.com/btcsuite/btcd v0.22.1 h1:CnwP9LM/M9xuRrGSCGeMVs9iv09uMqwsVX7EeIpgV2c= github.com/btcsuite/btcd v0.22.1/go.mod h1:wqgTSL29+50LRkmOVknEdmt8ZojIzhuWvgu/iptuN7Y= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 h1:q0rUy8C/TYNBQS1+CGKw68tLOFYSNEs0TFnxxnS9+4U= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= -github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f h1:bAs4lUbRJpnnkd9VhRV3jjAVU7DJVjMaK+IsvSeZvFo= github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f/go.mod h1:TdznJufoqS23FtqVCzL0ZqgP5MqXbb4fg/WgDys70nA= +github.com/btcsuite/btclog v1.0.0 h1:sEkpKJMmfGiyZjADwEIgB1NSwMyfdD1FB8v6+w1T0Ns= +github.com/btcsuite/btclog v1.0.0/go.mod h1:w7xnGOhwT3lmrS4H3b/D1XAXxvh+tbhUm8xeHN2y3TQ= github.com/btcsuite/btcutil v0.0.0-20190425235716-9e5f4b9a998d/go.mod
h1:+5NJ2+qvTyV9exUAL/rxXi3DcLg2Ts+ymUAY5y4NvMg= github.com/btcsuite/btcutil v1.0.3-0.20201208143702-a53e38424cce h1:YtWJF7RHm2pYCvA5t0RPmAaLUhREsKuKd+SLhxFbFeQ= github.com/btcsuite/btcutil v1.0.3-0.20201208143702-a53e38424cce/go.mod h1:0DVlHczLPewLcPGEIeUEzfOJhqGPQ0mJJRDBtD307+o= @@ -206,12 +207,12 @@ github.com/daixiang0/gci v0.13.5 h1:kThgmH1yBmZSBCh1EJVxQ7JsHpm5Oms0AMed/0LaH4c= github.com/daixiang0/gci v0.13.5/go.mod h1:12etP2OniiIdP4q+kjUGrC/rUagga7ODbqsom5Eo5Yk= github.com/dashpay/bls-signatures/go-bindings v0.0.0-20230207105415-06df92693ac8 h1:v4K3CiDoFY1gjcWL/scRcwzyjBwh8TVG3ek8cWolK1g= github.com/dashpay/bls-signatures/go-bindings v0.0.0-20230207105415-06df92693ac8/go.mod h1:auvGS60NBZ+a21aCCQh366PdsjDvHinsCvl28VrYPu4= -github.com/dashpay/dashd-go v0.25.0 h1:tswVRmM2fLHC/JhpuAZ5Oa0TpOO6L+tqiE+QLTCvIQc= -github.com/dashpay/dashd-go v0.25.0/go.mod h1:4yuk/laGME2RnQRTdqTbw87PhT+42hE1anLCnpkgls8= -github.com/dashpay/dashd-go/btcec/v2 v2.1.0 h1:fXwlLf5H+TtgHxjGMU74NesKzk6NisjKMPF04pBcygk= -github.com/dashpay/dashd-go/btcec/v2 v2.1.0/go.mod h1:1i8XtxdOmvK6mYEUCneVXTzFbrCUw3wq1u91j8gvsns= -github.com/dashpay/dashd-go/btcutil v1.2.0 h1:YMq7L0V0au5bbphIhpsBBc+nfOZqU+gJ4pkgRZB7Eiw= -github.com/dashpay/dashd-go/btcutil v1.2.0/go.mod h1:7UHoqUh3LY3OI4mEcogx0CnL3rtzDQyoqvsOCZZtvzE= +github.com/dashpay/dashd-go v0.26.1 h1:/ZFgtPw1fPHpvoJgKfXo/v63ZXddjJm8KrHRpxcSpy0= +github.com/dashpay/dashd-go v0.26.1/go.mod h1:7KKS2jSPkC1pTz9WLXpiXZ96wT5bUqKTRuk35AyRQ74= +github.com/dashpay/dashd-go/btcec/v2 v2.2.0 h1:tk54BC++OvOUu0vcPoG8+45dGoJXKsmupYAawBO/1Vk= +github.com/dashpay/dashd-go/btcec/v2 v2.2.0/go.mod h1:uOmCM/hVoJ1x6w+3SX+zQv+2LdrK3aO59RV41jNvTF4= +github.com/dashpay/dashd-go/btcutil v1.3.0 h1:yDX8tz7C/KhFHbGlRXBpNN+zlkmAgwkICD9DlAv/Vsc= +github.com/dashpay/dashd-go/btcutil v1.3.0/go.mod h1:sMWZ0iR8a/wmIA6b5+ccjOGUfq+iZvi5t6ECaLCW+kw= github.com/davecgh/go-spew v0.0.0-20171005155431-ecdeabc65495/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/blocksync/reactor.go b/internal/blocksync/reactor.go index 822550f165..b182229133 100644 --- a/internal/blocksync/reactor.go +++ b/internal/blocksync/reactor.go @@ -159,7 +159,7 @@ func (r *Reactor) OnStop() { // processPeerUpdate processes a PeerUpdate. func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpdate, client *client.Client) { - r.logger.Debug("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) + r.logger.Trace("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) // XXX: Pool#RedoRequest can sometimes give us an empty peer. 
if len(peerUpdate.NodeID) == 0 { diff --git a/internal/consensus/state.go b/internal/consensus/state.go index 02180c3305..69a6dcc6c4 100644 --- a/internal/consensus/state.go +++ b/internal/consensus/state.go @@ -307,6 +307,10 @@ func (cs *State) SetProposedAppVersion(ver uint64) { cs.emitter.Emit(setProposedAppVersionEventName, ver) } +func (cs *State) GetCurrentHeight() int64 { + return cs.stateDataStore.Get().Height +} + func (cs *State) updateStateFromStore() error { state, err := cs.stateStore.Load() if err != nil { diff --git a/internal/consensus/state_data.go b/internal/consensus/state_data.go index 0686b0f32d..66b0405711 100644 --- a/internal/consensus/state_data.go +++ b/internal/consensus/state_data.go @@ -119,7 +119,7 @@ func (s *StateDataStore) Subscribe(evsw *eventemitter.EventEmitter) { }) } -// StateData is a copy of the current RoundState nad state.State stored in the store +// StateData is a copy of the current RoundState and state.State stored in the store // Along with data, StateData provides some methods to check or update data inside type StateData struct { config *config.ConsensusConfig diff --git a/internal/evidence/reactor.go b/internal/evidence/reactor.go index d0ba641310..75689553e5 100644 --- a/internal/evidence/reactor.go +++ b/internal/evidence/reactor.go @@ -201,7 +201,7 @@ func (r *Reactor) processEvidenceCh(ctx context.Context) { // // REF: https://github.com/tendermint/tendermint/issues/4727 func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpdate) { - r.logger.Debug("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) + r.logger.Trace("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) r.mtx.Lock() defer r.mtx.Unlock() diff --git a/internal/libs/sync/mutexguard.go b/internal/libs/sync/mutexguard.go new file mode 100644 index 0000000000..5c1c1be18e --- /dev/null +++ b/internal/libs/sync/mutexguard.go @@ -0,0 +1,54 @@ +package sync + +// UnlockFn is a function that unlocks a mutex. +// It returns true if the mutex was unlocked, false if it was already unlocked. +type UnlockFn func() bool + +// Mtx is a mutex interface. +// Implemented by sync.Mutex and deadlock.Mutex. +type Mtx interface { + Lock() + Unlock() +} + +// RMtx is a mutex that can be locked for reading. +// +// Implemented by sync.RWMutex and deadlock.RWMutex. +type RMtx interface { + RLock() + RUnlock() +} + +// LockGuard locks the mutex and returns a function that unlocks it. +// The returned function must be called to release the lock. +// The returned function may be called multiple times - only the first call will unlock the mutex, others will be no-ops. +func LockGuard(mtx Mtx) UnlockFn { + mtx.Lock() + locked := true + + return func() bool { + if locked { + locked = false + mtx.Unlock() + return true + } + return false + } +} + +// RLockGuard locks the read-write mutex for reading and returns a function that unlocks it. +// The returned function must be called to release the lock. +// The returned function may be called multiple times - only the first call will unlock the mutex, others will be no-ops.
+func RLockGuard(mtx RMtx) UnlockFn { + mtx.RLock() + locked := true + + return func() bool { + if locked { + locked = false + mtx.RUnlock() + return true + } + return false + } +} diff --git a/internal/libs/sync/mutexguard_test.go b/internal/libs/sync/mutexguard_test.go new file mode 100644 index 0000000000..feafa118e2 --- /dev/null +++ b/internal/libs/sync/mutexguard_test.go @@ -0,0 +1,159 @@ +package sync_test + +import ( + "sync" + "testing" + "time" + + deadlock "github.com/sasha-s/go-deadlock" + "github.com/stretchr/testify/assert" + + tmsync "github.com/dashpay/tenderdash/internal/libs/sync" +) + +const ( + timeout = 1 * time.Second +) + +// TestLockGuardMultipleUnlocks checks that the LockGuard function correctly handles multiple unlock calls. +func TestLockGuardMultipleUnlocks(t *testing.T) { + // Disable deadlock detection logic for this test + deadlockDisabled := deadlock.Opts.Disable + deadlock.Opts.Disable = true + defer func() { + deadlock.Opts.Disable = deadlockDisabled + }() + var mtx deadlock.Mutex + + unlock := tmsync.LockGuard(&mtx) + // deferred unlock() will do nothing because we unlock inside the test, but we still want to check this + defer func() { assert.False(t, unlock()) }() + // locking should not be possible + assert.False(t, mtx.TryLock()) + + assert.True(t, unlock()) + // here we can lock + mtx.Lock() + // unlock should do nothing + assert.False(t, unlock()) + // locking again should not be possible + assert.False(t, mtx.TryLock()) + // unlock should do nothing + assert.False(t, unlock()) + // but this unlock should work + mtx.Unlock() + assert.True(t, mtx.TryLock()) +} + +// TestLockGuard checks that the LockGuard function correctly increments a counter using multiple goroutines. +func TestLockGuard(t *testing.T) { + var mtx deadlock.Mutex + var counter int + var wg sync.WaitGroup + + increment := func() { + defer wg.Done() + unlock := tmsync.LockGuard(&mtx) + defer unlock() + counter++ + } + + // Start multiple goroutines to increment the counter + for i := 0; i < 100; i++ { + wg.Add(1) + go increment() + } + waitFor(wg.Wait) + assert.Equal(t, 100, counter, "Counter should be incremented to 100") +} + +// TestRLockGuard checks that the RLockGuard function allows multiple read locks +// and correctly increments a counter using write locks. +func TestRLockGuard(t *testing.T) { + var mtx deadlock.RWMutex + var counter int + var wg sync.WaitGroup + + read := func() { + defer wg.Done() + unlock := tmsync.RLockGuard(&mtx) + defer unlock() + _ = counter // Just read the counter + } + + write := func() { + defer wg.Done() + unlock := tmsync.LockGuard(&mtx) + defer unlock() + counter++ + } + + // Start multiple goroutines to read the counter + for i := 0; i < 100; i++ { + wg.Add(1) + go read() + } + + // Start multiple goroutines to write to the counter + for i := 0; i < 10; i++ { + wg.Add(1) + go write() + } + + waitFor(wg.Wait) + assert.Equal(t, 10, counter, "Counter should be incremented to 10") +} + +// waitFor waits for the function `f` to finish or times out after `timeout`. +func waitFor(f func()) { + done := make(chan struct{}) + go func() { + f() + close(done) + }() + + select { + case <-time.After(timeout): + panic("Test timed out") + case <-done: + } +} + +// TestMixedLocks checks the behavior of mixed read and write locks, +// ensuring that the counter is correctly incremented while allowing concurrent reads. 
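A standalone sketch of the early-unlock idiom that LockGuard enables; the same pattern is used by chunkQueue.Add later in this diff to release its lock before a blocking channel send (the counter here is purely illustrative):

    package main

    import (
        "fmt"
        "sync"

        tmsync "github.com/dashpay/tenderdash/internal/libs/sync"
    )

    func main() {
        var mtx sync.Mutex // sync.Mutex satisfies the Mtx interface
        counter := 0

        unlock := tmsync.LockGuard(&mtx)
        // The deferred call acts as a safety net; it becomes a no-op once
        // the mutex has been unlocked explicitly below.
        defer unlock()

        counter++
        unlock() // release early, e.g. before a blocking channel send

        fmt.Println(counter) // 1; the deferred unlock() returns false
    }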
+func TestMixedLocks(t *testing.T) { + var mtx deadlock.RWMutex + var counter int + var wg sync.WaitGroup + + read := func() { + defer wg.Done() + unlock := tmsync.RLockGuard(&mtx) + defer unlock() + time.Sleep(10 * time.Millisecond) // Simulate read delay + _ = counter // Just read the counter + } + + write := func() { + defer wg.Done() + unlock := tmsync.LockGuard(&mtx) + defer unlock() + counter++ + time.Sleep(10 * time.Millisecond) // Simulate write delay + } + + // Start multiple goroutines to read the counter + for i := 0; i < 50; i++ { + wg.Add(1) + go read() + } + + // Start multiple goroutines to write to the counter + for i := 0; i < 5; i++ { + wg.Add(1) + go write() + } + + waitFor(wg.Wait) + assert.Equal(t, 5, counter, "Counter should be incremented to 5") +} diff --git a/internal/mempool/reactor.go b/internal/mempool/reactor.go index 415fef8bec..c00c364359 100644 --- a/internal/mempool/reactor.go +++ b/internal/mempool/reactor.go @@ -94,7 +94,7 @@ func (r *Reactor) OnStop() {} // removed peers, we remove the peer from the mempool peer ID set and signal to // stop the tx broadcasting goroutine. func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpdate) { - r.logger.Debug("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) + r.logger.Trace("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) r.mtx.Lock() defer r.mtx.Unlock() diff --git a/internal/p2p/transport_mconn.go b/internal/p2p/transport_mconn.go index 3757f8f7d4..920382b134 100644 --- a/internal/p2p/transport_mconn.go +++ b/internal/p2p/transport_mconn.go @@ -184,7 +184,7 @@ func (m *MConnTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connecti return nil, err } if endpoint.Port == 0 { - endpoint.Port = 26657 + endpoint.Port = 26656 } dialer := net.Dialer{} diff --git a/internal/statesync/chunks.go b/internal/statesync/chunks.go index bd9d23e880..8a11275a39 100644 --- a/internal/statesync/chunks.go +++ b/internal/statesync/chunks.go @@ -1,6 +1,8 @@ package statesync import ( + "crypto/sha256" + "encoding/hex" "errors" "fmt" "os" @@ -9,6 +11,8 @@ import ( sync "github.com/sasha-s/go-deadlock" + tmsync "github.com/dashpay/tenderdash/internal/libs/sync" + "github.com/dashpay/tenderdash/libs/bytes" "github.com/dashpay/tenderdash/types" ) @@ -63,6 +67,22 @@ type ( } ) +// Filename updates `chunkItem.file` with an absolute path to the file containing the chunk and returns it. +// If the filename is already set, it isn't changed. +// +// Returns an error if the filename cannot be created. +// +// Caller must ensure only one goroutine calls this method at a time, e.g. by holding the mutex lock. +func (c *chunkItem) Filename(parentDir string) (string, error) { + var err error + if c.file == "" { + hash := sha256.Sum256(c.chunkID) + filename := hex.EncodeToString(hash[:]) + c.file, err = filepath.Abs(filepath.Join(parentDir, filename)) + } + return c.file, err +} + // newChunkQueue creates a new chunk requestQueue for a snapshot, using a temp dir for storage. // Callers must call Close() when done. func newChunkQueue(snapshot *snapshot, tempDir string, bufLen int) (*chunkQueue, error) { @@ -133,36 +153,56 @@ func (q *chunkQueue) dequeue() (bytes.HexBytes, error) { // Add adds a chunk to the queue. It ignores chunks that already exist, returning false.
func (q *chunkQueue) Add(chunk *chunk) (bool, error) { - if chunk == nil || chunk.Chunk == nil { + if chunk == nil { return false, errChunkNil } - q.mtx.Lock() - defer q.mtx.Unlock() - if q.snapshot == nil { - return false, errNilSnapshot + + // empty chunk content is allowed, but we ensure it's not nil + data := chunk.Chunk + if data == nil { + data = []byte{} } - chunkIDKey := chunk.ID.String() - item, ok := q.items[chunkIDKey] - if !ok { - return false, fmt.Errorf("failed to add the chunk %x, it was never requested", chunk.ID) + + unlockFn := tmsync.LockGuard(&q.mtx) + defer unlockFn() + + item, err := q.getItem(chunk.ID) + if err != nil { + return false, fmt.Errorf("get chunk %x: %w", chunk.ID, err) } + if item.status != inProgressStatus && item.status != discardedStatus { + // chunk either already exists, or we didn't request it yet, so we ignore it return false, nil } - err := q.validateChunk(chunk) + + err = q.validateChunk(chunk) if err != nil { - return false, err + return false, fmt.Errorf("validate chunk %x: %w", chunk.ID, err) + } + + // ensure filename is set on the item + _, err = item.Filename(q.dir) + if err != nil { + return false, fmt.Errorf("failed to get filename for chunk %x: %w", chunk.ID, err) } - item.file = filepath.Join(q.dir, chunkIDKey) - err = item.write(chunk.Chunk) + + err = item.write(data) if err != nil { return false, err } item.sender = chunk.Sender item.status = receivedStatus + + // unlock before sending to applyCh to avoid blocking/deadlock on the applyCh + unlockFn() + q.applyCh <- chunk.ID // Signal any waiters that the chunk has arrived. + q.mtx.Lock() item.closeWaitChs(true) + q.mtx.Unlock() + return true, nil } @@ -284,8 +324,25 @@ func (q *chunkQueue) Next() (*chunk, error) { q.doneCount++ return loadedChunk, nil case <-time.After(chunkTimeout): - return nil, errTimeout + // Locking is done inside q.Pending + pendingChunks := len(q.Pending()) + return nil, fmt.Errorf("timed out waiting for %d chunks: %w", pendingChunks, errTimeout) + } +} + +// Pending returns a list of all chunks that have been requested but not yet received. +func (q *chunkQueue) Pending() []bytes.HexBytes { + q.mtx.Lock() + defer q.mtx.Unlock() + + // collect all chunks that are still awaited (init or in-progress status) + waiting := make([]bytes.HexBytes, 0, len(q.items)) + for _, item := range q.items { + if item.status == initStatus || item.status == inProgressStatus { + waiting = append(waiting, item.chunkID) + } } + return waiting } // Retry schedules a chunk to be retried, without refetching it. @@ -296,12 +353,13 @@ func (q *chunkQueue) Retry(chunkID bytes.HexBytes) { } func (q *chunkQueue) retry(chunkID bytes.HexBytes) { - item, ok := q.items[chunkID.String()] + chunkKey := chunkID.String() + item, ok := q.items[chunkKey] if !ok || (item.status != receivedStatus && item.status != doneStatus) { return } q.requestQueue = append(q.requestQueue, chunkID) - q.items[chunkID.String()].status = initStatus + q.items[chunkKey].status = initStatus } // RetryAll schedules all chunks to be retried, without refetching them. @@ -345,6 +403,23 @@ func (q *chunkQueue) DoneChunksCount() int { return q.doneCount } +// getItem fetches a chunk item from the items map. If the item is not found, it returns an error. +// The caller must hold the mutex lock.
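As a worked example of the naming scheme chunkItem.Filename introduces above, each chunk file is the hex-encoded SHA-256 of the chunk ID, joined with the queue's temp dir (the directory here is hypothetical):

    package main

    import (
        "crypto/sha256"
        "encoding/hex"
        "fmt"
        "path/filepath"
    )

    func main() {
        chunkID := []byte{0x01, 0x02}
        // Mirrors chunkItem.Filename: sha256(chunkID), hex-encoded.
        hash := sha256.Sum256(chunkID)
        name := hex.EncodeToString(hash[:])
        fmt.Println(filepath.Join("/tmp/statesync", name))
    }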
+func (q *chunkQueue) getItem(chunkID bytes.HexBytes) (*chunkItem, error) { + if q.snapshot == nil { + return nil, errNilSnapshot + } + chunkIDKey := chunkID.String() + item, ok := q.items[chunkIDKey] + if !ok { + return nil, fmt.Errorf("chunk %x not found", chunkID) + } + + return item, nil +} + +// validateChunk checks if the chunk is expected and valid for the current snapshot +// The caller must hold the mutex lock. func (q *chunkQueue) validateChunk(chunk *chunk) error { if chunk.Height != q.snapshot.Height { return fmt.Errorf("invalid chunk height %v, expected %v", diff --git a/internal/statesync/chunks_test.go b/internal/statesync/chunks_test.go index 0ecb01946e..e3be49683a 100644 --- a/internal/statesync/chunks_test.go +++ b/internal/statesync/chunks_test.go @@ -1,6 +1,7 @@ package statesync import ( + "errors" "os" "testing" @@ -100,10 +101,10 @@ func (suite *ChunkQueueTestSuite) TestChunkQueue() { {chunk: suite.chunks[1], want: true}, } require := suite.Require() - for _, tc := range testCases { + for i, tc := range testCases { added, err := suite.queue.Add(tc.chunk) - require.NoError(err) - require.Equal(tc.want, added) + require.NoError(err, "test case %d", i) + require.Equal(tc.want, added, "test case %d", i) } // At this point, we should be able to retrieve them all via Next @@ -244,7 +245,7 @@ func (suite *ChunkQueueTestSuite) TestNext() { go func() { for { c, err := suite.queue.Next() - if err == errDone { + if errors.Is(err, errDone) { close(chNext) break } @@ -284,7 +285,7 @@ func (suite *ChunkQueueTestSuite) TestNextClosed() { require.NoError(err) _, err = suite.queue.Next() - require.Equal(errDone, err) + require.ErrorIs(err, errDone) } func (suite *ChunkQueueTestSuite) TestRetry() { diff --git a/internal/statesync/mocks/consensusstateprovider.go b/internal/statesync/mocks/consensusstateprovider.go new file mode 100644 index 0000000000..4f12038ac5 --- /dev/null +++ b/internal/statesync/mocks/consensusstateprovider.go @@ -0,0 +1,127 @@ +// Code generated by mockery. DO NOT EDIT. 
+ +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + types "github.com/dashpay/tenderdash/types" +) + +// ConsensusStateProvider is an autogenerated mock type for the ConsensusStateProvider type +type ConsensusStateProvider struct { + mock.Mock +} + +type ConsensusStateProvider_Expecter struct { + mock *mock.Mock +} + +func (_m *ConsensusStateProvider) EXPECT() *ConsensusStateProvider_Expecter { + return &ConsensusStateProvider_Expecter{mock: &_m.Mock} +} + +// GetCurrentHeight provides a mock function with no fields +func (_m *ConsensusStateProvider) GetCurrentHeight() int64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCurrentHeight") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// ConsensusStateProvider_GetCurrentHeight_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCurrentHeight' +type ConsensusStateProvider_GetCurrentHeight_Call struct { + *mock.Call +} + +// GetCurrentHeight is a helper method to define mock.On call +func (_e *ConsensusStateProvider_Expecter) GetCurrentHeight() *ConsensusStateProvider_GetCurrentHeight_Call { + return &ConsensusStateProvider_GetCurrentHeight_Call{Call: _e.mock.On("GetCurrentHeight")} +} + +func (_c *ConsensusStateProvider_GetCurrentHeight_Call) Run(run func()) *ConsensusStateProvider_GetCurrentHeight_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *ConsensusStateProvider_GetCurrentHeight_Call) Return(_a0 int64) *ConsensusStateProvider_GetCurrentHeight_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ConsensusStateProvider_GetCurrentHeight_Call) RunAndReturn(run func() int64) *ConsensusStateProvider_GetCurrentHeight_Call { + _c.Call.Return(run) + return _c +} + +// PublishCommitEvent provides a mock function with given fields: commit +func (_m *ConsensusStateProvider) PublishCommitEvent(commit *types.Commit) error { + ret := _m.Called(commit) + + if len(ret) == 0 { + panic("no return value specified for PublishCommitEvent") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*types.Commit) error); ok { + r0 = rf(commit) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ConsensusStateProvider_PublishCommitEvent_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PublishCommitEvent' +type ConsensusStateProvider_PublishCommitEvent_Call struct { + *mock.Call +} + +// PublishCommitEvent is a helper method to define mock.On call +// - commit *types.Commit +func (_e *ConsensusStateProvider_Expecter) PublishCommitEvent(commit interface{}) *ConsensusStateProvider_PublishCommitEvent_Call { + return &ConsensusStateProvider_PublishCommitEvent_Call{Call: _e.mock.On("PublishCommitEvent", commit)} +} + +func (_c *ConsensusStateProvider_PublishCommitEvent_Call) Run(run func(commit *types.Commit)) *ConsensusStateProvider_PublishCommitEvent_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*types.Commit)) + }) + return _c +} + +func (_c *ConsensusStateProvider_PublishCommitEvent_Call) Return(_a0 error) *ConsensusStateProvider_PublishCommitEvent_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ConsensusStateProvider_PublishCommitEvent_Call) RunAndReturn(run func(*types.Commit) error) *ConsensusStateProvider_PublishCommitEvent_Call { + _c.Call.Return(run) + return _c +} + +// NewConsensusStateProvider creates a new instance of ConsensusStateProvider. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewConsensusStateProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *ConsensusStateProvider { + mock := &ConsensusStateProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/statesync/peer.go b/internal/statesync/peer.go index 589acecacc..3e1e49d3d7 100644 --- a/internal/statesync/peer.go +++ b/internal/statesync/peer.go @@ -96,7 +96,7 @@ func (p *PeerSubscriber) Stop(ctx context.Context) { // processPeerUpdate processes a PeerUpdate, returning an error upon failing to // handle the PeerUpdate or if a panic is recovered. func (p *PeerSubscriber) execute(ctx context.Context, peerUpdate p2p.PeerUpdate) error { - p.logger.Info("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) + p.logger.Trace("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) handler, ok := p.handles[peerUpdate.Status] if !ok { // TODO: return error or write a log @@ -106,7 +106,7 @@ func (p *PeerSubscriber) execute(ctx context.Context, peerUpdate p2p.PeerUpdate) if err != nil { return err } - p.logger.Info("processed peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) + p.logger.Trace("processed peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) return nil } diff --git a/internal/statesync/reactor.go b/internal/statesync/reactor.go index e31088ff4e..679c242df6 100644 --- a/internal/statesync/reactor.go +++ b/internal/statesync/reactor.go @@ -16,12 +16,12 @@ import ( abci "github.com/dashpay/tenderdash/abci/types" "github.com/dashpay/tenderdash/config" dashcore "github.com/dashpay/tenderdash/dash/core" - "github.com/dashpay/tenderdash/internal/consensus" "github.com/dashpay/tenderdash/internal/eventbus" "github.com/dashpay/tenderdash/internal/p2p" sm "github.com/dashpay/tenderdash/internal/state" "github.com/dashpay/tenderdash/internal/store" "github.com/dashpay/tenderdash/libs/log" + tmmath "github.com/dashpay/tenderdash/libs/math" "github.com/dashpay/tenderdash/libs/service" "github.com/dashpay/tenderdash/light/provider" ssproto "github.com/dashpay/tenderdash/proto/tendermint/statesync" @@ -68,6 +68,9 @@ const ( // backfillSleepTime uses to sleep if no connected peers to fetch light blocks backfillSleepTime = 1 * time.Second + + // minPeers is the minimum number of peers required to start a state sync + minPeers = 2 ) func getChannelDescriptors() map[p2p.ChannelID]*p2p.ChannelDescriptor { @@ -120,7 +123,7 @@ type Reactor struct { // providers. mtx sync.RWMutex initSyncer func() *syncer - requestSnaphot func() error + requestSnapshot func() error syncer *syncer // syncer is nil when sync is not in progress initStateProvider func(ctx context.Context, chainID string, initialHeight int64) error stateProvider StateProvider @@ -132,7 +135,16 @@ type Reactor struct { dashCoreClient dashcore.Client - csState *consensus.State + csState ConsensusStateProvider +} + +// ConsensusStateProvider is an interface that allows the state sync reactor to +// interact with the consensus state. It is defined to improve testability. 
+// +// Implemented by consensus.State +type ConsensusStateProvider interface { + PublishCommitEvent(commit *types.Commit) error + GetCurrentHeight() int64 } // NewReactor returns a reference to a new state sync reactor, which implements @@ -155,7 +167,7 @@ func NewReactor( postSyncHook func(context.Context, sm.State) error, needsStateSync bool, client dashcore.Client, - csState *consensus.State, + csState ConsensusStateProvider, ) *Reactor { r := &Reactor{ logger: logger, @@ -225,7 +237,7 @@ func (r *Reactor) OnStart(ctx context.Context) error { } } r.dispatcher = NewDispatcher(blockCh, r.logger) - r.requestSnaphot = func() error { + r.requestSnapshot = func() error { // request snapshots from all currently connected peers return snapshotCh.Send(ctx, p2p.Envelope{ Broadcast: true, @@ -236,11 +248,10 @@ func (r *Reactor) OnStart(ctx context.Context) error { r.initStateProvider = func(ctx context.Context, chainID string, initialHeight int64) error { spLogger := r.logger.With("module", "stateprovider") - spLogger.Info("initializing state provider", - "trustHeight", r.cfg.TrustHeight, "useP2P", r.cfg.UseP2P) + spLogger.Debug("initializing state sync state provider", "useP2P", r.cfg.UseP2P) if r.cfg.UseP2P { - if err := r.waitForEnoughPeers(ctx, 2); err != nil { + if err := r.waitForEnoughPeers(ctx, minPeers); err != nil { return err } @@ -250,8 +261,8 @@ func (r *Reactor) OnStart(ctx context.Context) error { providers[idx] = NewBlockProvider(p, chainID, r.dispatcher) } - stateProvider, err := NewP2PStateProvider(ctx, chainID, initialHeight, r.cfg.TrustHeight, providers, - paramsCh, r.logger.With("module", "stateprovider"), r.dashCoreClient) + stateProvider, err := NewP2PStateProvider(ctx, chainID, initialHeight, + providers, paramsCh, r.logger.With("module", "stateprovider"), r.dashCoreClient) if err != nil { return fmt.Errorf("failed to initialize P2P state provider: %w", err) } @@ -259,8 +270,7 @@ func (r *Reactor) OnStart(ctx context.Context) error { return nil } - stateProvider, err := NewRPCStateProvider(ctx, chainID, initialHeight, r.cfg.RPCServers, r.cfg.TrustHeight, - spLogger, r.dashCoreClient) + stateProvider, err := NewRPCStateProvider(ctx, chainID, initialHeight, r.cfg.RPCServers, spLogger, r.dashCoreClient) if err != nil { return fmt.Errorf("failed to initialize RPC state provider: %w", err) } @@ -279,8 +289,21 @@ func (r *Reactor) OnStart(ctx context.Context) error { if r.needsStateSync { r.logger.Info("starting state sync") if _, err := r.Sync(ctx); err != nil { - r.logger.Error("state sync failed; shutting down this node", "error", err) - return err + if errors.Is(err, errNoSnapshots) && r.postSyncHook != nil { + r.logger.Warn("no snapshots available; falling back to block sync", "err", err) + + state, err := r.stateStore.Load() + if err != nil { + return fmt.Errorf("failed to load state: %w", err) + } + + if err := r.postSyncHook(ctx, state); err != nil { + return fmt.Errorf("post sync failed: %w", err) + } + } else { + r.logger.Error("state sync failed; shutting down this node", "err", err) + return err + } } } @@ -311,7 +334,7 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { // We need at least two peers (for cross-referencing of light blocks) before we can // begin state sync - if err := r.waitForEnoughPeers(ctx, 2); err != nil { + if err := r.waitForEnoughPeers(ctx, minPeers); err != nil { return sm.State{}, fmt.Errorf("wait for peers: %w", err) } @@ -329,7 +352,7 @@ func (r *Reactor) Sync(ctx context.Context) (sm.State, error) { } 
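Because the reactor now depends on the ConsensusStateProvider interface rather than on *consensus.State directly, a test double no longer needs a full consensus state; a minimal hand-written stub (hypothetical; the generated mock earlier in this diff serves the same purpose) could look like:

    package statesync_test

    import (
        "github.com/dashpay/tenderdash/internal/statesync"
        "github.com/dashpay/tenderdash/types"
    )

    // fixedHeightProvider is a hypothetical stub satisfying ConsensusStateProvider.
    type fixedHeightProvider struct{ height int64 }

    func (f *fixedHeightProvider) GetCurrentHeight() int64 { return f.height }

    func (f *fixedHeightProvider) PublishCommitEvent(_ *types.Commit) error { return nil }

    // Compile-time check that the stub implements the interface.
    var _ statesync.ConsensusStateProvider = (*fixedHeightProvider)(nil)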
r.getSyncer().SetStateProvider(r.stateProvider) - state, commit, err := r.syncer.SyncAny(ctx, r.cfg.DiscoveryTime, r.requestSnaphot) + state, commit, err := r.syncer.SyncAny(ctx, r.cfg.DiscoveryTime, r.cfg.Retries, r.requestSnapshot) if err != nil { return sm.State{}, fmt.Errorf("sync any: %w", err) } @@ -423,6 +446,9 @@ func (r *Reactor) Backfill(ctx context.Context, state sm.State) error { params := state.ConsensusParams.Evidence stopHeight := state.LastBlockHeight - params.MaxAgeNumBlocks stopTime := state.LastBlockTime.Add(-params.MaxAgeDuration) + // To make tests on mainnet faster, we can use: + // stopHeight := state.LastBlockHeight - 500 + // stopTime := state.LastBlockTime.Add(-24 * time.Hour) // ensure that stop height doesn't go below the initial height if stopHeight < state.InitialHeight { stopHeight = state.InitialHeight @@ -696,7 +722,7 @@ func (r *Reactor) handleSnapshotMessage(ctx context.Context, envelope *p2p.Envel "version", msg.Version) default: - return fmt.Errorf("received unknown message: %T", msg) + return fmt.Errorf("handleSnapshotMessage received unknown message: %T", msg) } return nil @@ -777,7 +803,7 @@ func (r *Reactor) handleChunkMessage(ctx context.Context, envelope *p2p.Envelope } default: - return fmt.Errorf("received unknown message: %T", msg) + return fmt.Errorf("handleChunkMessage received unknown message: %T", msg) } return nil @@ -838,7 +864,7 @@ func (r *Reactor) handleLightBlockMessage(ctx context.Context, envelope *p2p.Env } default: - return fmt.Errorf("received unknown message: %T", msg) + return fmt.Errorf("handleLightBlockMessage received unknown message: %T", msg) } return nil @@ -848,7 +874,7 @@ func (r *Reactor) handleParamsMessage(ctx context.Context, envelope *p2p.Envelop switch msg := envelope.Message.(type) { case *ssproto.ParamsRequest: r.logger.Debug("received consensus params request", "height", msg.Height) - cp, err := r.stateStore.LoadConsensusParams(int64(msg.Height)) + cp, err := r.stateStore.LoadConsensusParams(tmmath.MustConvertInt64(msg.Height)) if err != nil { r.logger.Error("failed to fetch requested consensus params", "height", msg.Height, @@ -886,7 +912,7 @@ func (r *Reactor) handleParamsMessage(ctx context.Context, envelope *p2p.Envelop } default: - return fmt.Errorf("received unknown message: %T", msg) + return fmt.Errorf("handleParamsMessage received unknown message: %T", msg) } return nil @@ -971,7 +997,7 @@ func (r *Reactor) processChannels(ctx context.Context, chanTable map[p2p.Channel // processPeerUpdate processes a PeerUpdate, returning an error upon failing to // handle the PeerUpdate or if a panic is recovered. 
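The evidence-age window that Backfill computes above can be checked with a small worked example (all values hypothetical):

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        // Hypothetical chain state: last block 1000, evidence params keep
        // 100 blocks / 1h of history, initial height 1.
        lastBlockHeight := int64(1000)
        lastBlockTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
        maxAgeNumBlocks := int64(100)
        maxAgeDuration := time.Hour
        initialHeight := int64(1)

        stopHeight := lastBlockHeight - maxAgeNumBlocks
        stopTime := lastBlockTime.Add(-maxAgeDuration)
        // ensure that stop height doesn't go below the initial height
        if stopHeight < initialHeight {
            stopHeight = initialHeight
        }
        fmt.Println(stopHeight, stopTime) // 900 2024-01-01 11:00:00 +0000 UTC
    }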
func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpdate) { - r.logger.Info("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) + r.logger.Trace("received peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) switch peerUpdate.Status { case p2p.PeerStatusUp: @@ -980,7 +1006,7 @@ func (r *Reactor) processPeerUpdate(ctx context.Context, peerUpdate p2p.PeerUpda r.processPeerDown(ctx, peerUpdate) } - r.logger.Info("processed peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) + r.logger.Trace("processed peer update", "peer", peerUpdate.NodeID, "status", peerUpdate.Status) } func (r *Reactor) processPeerUp(ctx context.Context, peerUpdate p2p.PeerUpdate) { @@ -1038,6 +1064,12 @@ func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerU // recentSnapshots fetches the n most recent snapshots from the app func (r *Reactor) recentSnapshots(ctx context.Context, n uint32) ([]*snapshot, error) { + // if we don't have the current state, we don't return any snapshots + if r.csState == nil { + return []*snapshot{}, nil + } + currentHeight := r.csState.GetCurrentHeight() + resp, err := r.conn.ListSnapshots(ctx, &abci.RequestListSnapshots{}) if err != nil { return nil, err @@ -1063,6 +1095,14 @@ func (r *Reactor) recentSnapshots(ctx context.Context, n uint32) ([]*snapshot, e break } + // we only accept snapshots whose next block is already finalized, i.e. we are voting + // for `height + 2` or higher, because we need to be able to fetch the light block containing + // the commit for `height` from the block store (it is stored in block `height+1`) + if tmmath.MustConvertInt64(s.Height) >= currentHeight-2 { + r.logger.Debug("snapshot too new, skipping", "height", s.Height, "state_height", currentHeight) + continue + } + snapshots = append(snapshots, &snapshot{ Height: s.Height, Version: s.Version, @@ -1077,7 +1117,7 @@ func (r *Reactor) recentSnapshots(ctx context.Context, n uint32) ([]*snapshot, e // fetchLightBlock works out whether the node has a light block at a particular // height and if so returns it so it can be gossiped to peers func (r *Reactor) fetchLightBlock(height uint64) (*types.LightBlock, error) { - h := int64(height) + h := tmmath.MustConvertInt64(height) blockMeta := r.blockStore.LoadBlockMeta(h) if blockMeta == nil { diff --git a/internal/statesync/reactor_test.go b/internal/statesync/reactor_test.go index eff91961c4..f8fd92e9c7 100644 --- a/internal/statesync/reactor_test.go +++ b/internal/statesync/reactor_test.go @@ -21,6 +21,7 @@ import ( "github.com/dashpay/tenderdash/config" "github.com/dashpay/tenderdash/crypto" dashcore "github.com/dashpay/tenderdash/dash/core" + "github.com/dashpay/tenderdash/internal/p2p" "github.com/dashpay/tenderdash/internal/proxy" smmocks "github.com/dashpay/tenderdash/internal/state/mocks" @@ -86,6 +87,7 @@ func setup( t *testing.T, conn *clientmocks.Client, stateProvider *mocks.StateProvider, + csState ConsensusStateProvider, chBuf uint, ) *reactorTestSuite { t.Helper() @@ -187,7 +189,7 @@ func setup( nil, // post-sync-hook false, // run Sync during Start() rts.dashcoreClient, - nil, + csState, ) rts.syncer = &syncer{ @@ -224,7 +226,7 @@ func TestReactor_Sync(t *testing.T) { defer cancel() const snapshotHeight = 7 - rts := setup(ctx, t, nil, nil, 100) + rts := setup(ctx, t, nil, nil, nil, 100) chain := buildLightBlockChain(ctx, t, 1, 10, time.Now(), rts.privVal) // app accepts any snapshot rts.conn.
@@ -277,8 +279,6 @@ func TestReactor_Sync(t *testing.T) { // update the config to use the p2p provider rts.reactor.cfg.UseP2P = true - rts.reactor.cfg.TrustHeight = 1 - rts.reactor.cfg.TrustHash = fmt.Sprintf("%X", chain[1].Hash()) rts.reactor.cfg.DiscoveryTime = 1 * time.Second // Run state sync @@ -290,7 +290,7 @@ func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rts := setup(ctx, t, nil, nil, 2) + rts := setup(ctx, t, nil, nil, nil, 2) rts.chunkInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -350,7 +350,7 @@ func TestReactor_ChunkRequest(t *testing.T) { ChunkId: tc.request.ChunkId, }).Return(&abci.ResponseLoadSnapshotChunk{Chunk: tc.chunk}, nil) - rts := setup(ctx, t, conn, nil, 2) + rts := setup(ctx, t, conn, nil, nil, 2) rts.chunkInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -371,7 +371,7 @@ func TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rts := setup(ctx, t, nil, nil, 2) + rts := setup(ctx, t, nil, nil, nil, 2) rts.snapshotInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -390,8 +390,9 @@ func TestReactor_SnapshotsRequest(t *testing.T) { testcases := map[string]struct { snapshots []*abci.Snapshot expectResponses []*ssproto.SnapshotsResponse + currentHeight int64 }{ - "no snapshots": {nil, []*ssproto.SnapshotsResponse{}}, + "no snapshots": {nil, []*ssproto.SnapshotsResponse{}, 1}, ">10 unordered snapshots": { []*abci.Snapshot{ {Height: 1, Version: 2, Hash: []byte{1, 2}, Metadata: []byte{1}}, @@ -419,25 +420,27 @@ func TestReactor_SnapshotsRequest(t *testing.T) { {Height: 1, Version: 4, Hash: []byte{1, 4}, Metadata: []byte{7}}, {Height: 1, Version: 3, Hash: []byte{1, 3}, Metadata: []byte{10}}, }, + 6, }, } ctx, cancel := context.WithCancel(context.Background()) defer cancel() for name, tc := range testcases { - tc := tc - t.Run(name, func(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() // mock ABCI connection to return local snapshots - conn := &clientmocks.Client{} + conn := clientmocks.NewClient(t) conn.On("ListSnapshots", mock.Anything, &abci.RequestListSnapshots{}).Return(&abci.ResponseListSnapshots{ Snapshots: tc.snapshots, - }, nil) + }, nil).Maybe() - rts := setup(ctx, t, conn, nil, 100) + consensusStateProvider := mocks.NewConsensusStateProvider(t) + consensusStateProvider.On("GetCurrentHeight").Return(tc.currentHeight).Maybe() + + rts := setup(ctx, t, conn, nil, consensusStateProvider, 100) rts.snapshotInCh <- p2p.Envelope{ From: types.NodeID("aa"), @@ -446,7 +449,10 @@ func TestReactor_SnapshotsRequest(t *testing.T) { } if len(tc.expectResponses) > 0 { - retryUntil(ctx, t, func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, time.Second) + retryUntil(ctx, t, + func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, + time.Second, + ) } responses := make([]*ssproto.SnapshotsResponse, len(tc.expectResponses)) @@ -465,7 +471,7 @@ func TestReactor_LightBlockResponse(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rts := setup(ctx, t, nil, nil, 2) + rts := setup(ctx, t, nil, nil, nil, 2) var height int64 = 10 // generates a random header @@ -523,7 +529,7 @@ func TestReactor_BlockProviders(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rts := setup(ctx, t, nil, nil, 2) + rts := setup(ctx, t, nil, nil, nil, 2) rts.peerUpdateCh <- p2p.PeerUpdate{ NodeID: "aa", 
Status: p2p.PeerStatusUp, @@ -590,7 +596,7 @@ func TestReactor_StateProviderP2P(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rts := setup(ctx, t, nil, nil, 2) + rts := setup(ctx, t, nil, nil, nil, 2) // make syncer non nil else test won't think we are state syncing rts.reactor.syncer = rts.syncer peerA := types.NodeID(strings.Repeat("a", 2*types.NodeIDByteLength)) @@ -612,8 +618,6 @@ func TestReactor_StateProviderP2P(t *testing.T) { go handleConsensusParamsRequest(ctx, t, rts.paramsOutCh, rts.paramsInCh, closeCh) rts.reactor.cfg.UseP2P = true - rts.reactor.cfg.TrustHeight = 1 - rts.reactor.cfg.TrustHash = fmt.Sprintf("%X", chain[1].Hash()) for _, p := range []types.NodeID{peerA, peerB} { if !rts.reactor.peers.Contains(p) { @@ -693,7 +697,7 @@ func TestReactor_Backfill(t *testing.T) { defer cancel() t.Cleanup(leaktest.CheckTimeout(t, 1*time.Minute)) - rts := setup(ctx, t, nil, nil, 21) + rts := setup(ctx, t, nil, nil, nil, 21) peers := genPeerIDs(tc.numPeers) for _, peer := range peers { @@ -764,6 +768,8 @@ func TestReactor_Backfill(t *testing.T) { // retryUntil will continue to evaluate fn and will return successfully when true // or fail when the timeout is reached. func retryUntil(ctx context.Context, t *testing.T, fn func() bool, timeout time.Duration) { + t.Helper() + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() diff --git a/internal/statesync/stateprovider.go b/internal/statesync/stateprovider.go index b10324acc5..c8cdd9d9d4 100644 --- a/internal/statesync/stateprovider.go +++ b/internal/statesync/stateprovider.go @@ -42,6 +42,8 @@ type StateProvider interface { State(ctx context.Context, height uint64) (sm.State, error) } +// stateProviderRPC is a state provider that fetches state over RPC and verifies it with a light client. +// Deprecated: will be removed in the future. type stateProviderRPC struct { sync.Mutex // light.Client is not concurrency-safe lc *light.Client @@ -51,12 +53,12 @@ type stateProviderRPC struct { } // NewRPCStateProvider creates a new StateProvider using a light client and RPC clients. +// Deprecated: will be removed in the future. func NewRPCStateProvider( ctx context.Context, chainID string, initialHeight int64, servers []string, - trustHeight int64, logger log.Logger, dashCoreClient dashcore.Client, ) (StateProvider, error) { @@ -77,8 +79,7 @@ func NewRPCStateProvider( // provider used by the light client and use it to fetch consensus parameters.
providerRemotes[provider] = server } - - lc, err := light.NewClientAtHeight(ctx, trustHeight, chainID, providers[0], providers[1:], + lc, err := light.NewClient(ctx, chainID, providers[0], providers[1:], lightdb.New(dbm.NewMemDB()), dashCoreClient, light.Logger(logger)) if err != nil { return nil, err } @@ -208,17 +209,16 @@ func NewP2PStateProvider( ctx context.Context, chainID string, initialHeight int64, - trustHeight int64, providers []lightprovider.Provider, paramsSendCh p2p.Channel, logger log.Logger, dashCoreClient dashcore.Client, ) (StateProvider, error) { - if len(providers) < 2 { - return nil, fmt.Errorf("at least 2 peers are required, got %d", len(providers)) + if len(providers) < minPeers { + return nil, fmt.Errorf("at least %d peers are required, got %d", minPeers, len(providers)) } - lc, err := light.NewClientAtHeight(ctx, trustHeight, chainID, providers[0], providers[1:], + lc, err := light.NewClient(ctx, chainID, providers[0], providers[1:], lightdb.New(dbm.NewMemDB()), dashCoreClient, light.Logger(logger)) if err != nil { return nil, err } diff --git a/internal/statesync/syncer.go b/internal/statesync/syncer.go index 24ce2eb0a6..d72f360094 100644 --- a/internal/statesync/syncer.go +++ b/internal/statesync/syncer.go @@ -16,18 +16,22 @@ import ( sm "github.com/dashpay/tenderdash/internal/state" tmbytes "github.com/dashpay/tenderdash/libs/bytes" "github.com/dashpay/tenderdash/libs/log" + tmmath "github.com/dashpay/tenderdash/libs/math" "github.com/dashpay/tenderdash/light" ssproto "github.com/dashpay/tenderdash/proto/tendermint/statesync" "github.com/dashpay/tenderdash/types" ) const ( + // chunkTimeout is the timeout while waiting for the next chunk from the chunk queue. chunkTimeout = 2 * time.Minute // minimumDiscoveryTime is the lowest allowable time for a // SyncAny discovery time. minimumDiscoveryTime = 5 * time.Second + // chunkRequestSendTimeout is the timeout for sending chunk requests to peers. + chunkRequestSendTimeout = 5 * time.Second dequeueChunkIDTimeoutDefault = 2 * time.Second ) @@ -89,7 +93,7 @@ func (s *syncer) AddChunk(chunk *chunk) (bool, error) { keyVals := []any{ "height", chunk.Height, "version", chunk.Version, - "chunk", chunk.ID, + "chunkID", chunk.ID, } added, err := s.chunkQueue.Add(chunk) if err != nil { @@ -120,6 +124,9 @@ func (s *syncer) AddSnapshot(peerID types.NodeID, snapshot *snapshot) (bool, err "height", snapshot.Height, "version", snapshot.Version, "hash", snapshot.Hash.ShortString()) + } else { + s.logger.Debug("snapshot not added, possibly duplicate or invalid", + "height", snapshot.Height, "hash", snapshot.Hash) } return added, nil } @@ -147,6 +154,7 @@ func (s *syncer) RemovePeer(peerID types.NodeID) { func (s *syncer) SyncAny( ctx context.Context, discoveryTime time.Duration, + retries int, requestSnapshots func() error, ) (sm.State, *types.Commit, error) { if discoveryTime != 0 && discoveryTime < minimumDiscoveryTime { @@ -156,19 +164,6 @@ func (s *syncer) SyncAny( timer := time.NewTimer(discoveryTime) defer timer.Stop() - if discoveryTime > 0 { - if err := requestSnapshots(); err != nil { - return sm.State{}, nil, err - } - s.logger.Info("discovering snapshots", - "interval", discoveryTime) - select { - case <-ctx.Done(): - return sm.State{}, nil, ctx.Err() - case <-timer.C: - } - } - // The app may ask us to retry a snapshot restoration, in which case we need to reuse // the snapshot and chunk queue from the previous loop iteration.
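The retry accounting in the SyncAny loop that follows allows one pass more than `retries`, so snapshots requested in earlier iterations can still arrive; in isolation (with discovery and sync elided) the loop behaves like this sketch:

    package main

    import "fmt"

    func main() {
        retries := 3
        iters := 0
        var snapshot *struct{} // stays nil: no snapshot is ever discovered here

        for {
            // Mirrors the guard in SyncAny below: give up only after
            // retries+1 iterations without a snapshot.
            if retries > 0 && snapshot == nil && iters > retries {
                fmt.Println("falling back to block sync after", iters, "iterations")
                return
            }
            iters++
            // ... request snapshots, wait discoveryTime, attempt restore ...
        }
    }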
var ( @@ -179,6 +174,11 @@ func (s *syncer) SyncAny( ) for { + // we loop one more time than `retries` to check if snapshots requested in previous iterations are available + if retries > 0 && snapshot == nil && iters > retries { + return sm.State{}, nil, errNoSnapshots + } + iters++ // If not nil, we're going to retry restoration of the same snapshot. if snapshot == nil { @@ -189,6 +189,10 @@ func (s *syncer) SyncAny( if discoveryTime == 0 { return sm.State{}, nil, errNoSnapshots } + // we re-request snapshots + if err := requestSnapshots(); err != nil { + return sm.State{}, nil, err + } s.logger.Info("discovering snapshots", "iterations", iters, "interval", discoveryTime) @@ -215,7 +219,7 @@ func (s *syncer) SyncAny( switch { case err == nil: s.metrics.SnapshotHeight.Set(float64(snapshot.Height)) - s.lastSyncedSnapshotHeight = int64(snapshot.Height) + s.lastSyncedSnapshotHeight = tmmath.MustConvertInt64(snapshot.Height) return newState, commit, nil case errors.Is(err, errAbort): @@ -342,7 +346,7 @@ func (s *syncer) Sync(ctx context.Context, snapshot *snapshot, queue *chunkQueue if ctx.Err() != nil { return sm.State{}, nil, ctx.Err() } - if err == light.ErrNoWitnesses { + if errors.Is(err, light.ErrNoWitnesses) { return sm.State{}, nil, fmt.Errorf("failed to get tendermint state at height %d. No witnesses remaining", snapshot.Height) } @@ -360,7 +364,7 @@ func (s *syncer) Sync(ctx context.Context, snapshot *snapshot, queue *chunkQueue return sm.State{}, nil, fmt.Errorf("failed to get commit at height %d. No witnesses remaining", snapshot.Height) } - s.logger.Info("failed to get and verify commit. Dropping snapshot and trying again", + s.logger.Error("failed to get and verify light block. Dropping snapshot and trying again", "err", err, "height", snapshot.Height) return sm.State{}, nil, errRejectSnapshot } @@ -509,31 +513,41 @@ func (s *syncer) fetchChunks(ctx context.Context, snapshot *snapshot, queue *chu } for { if queue.IsRequestQueueEmpty() { + s.logger.Debug("fetchChunks queue empty, waiting for chunk", "timeout", dequeueChunkIDTimeout) select { case <-ctx.Done(): + s.logger.Debug("fetchChunks context done on empty queue") return case <-time.After(dequeueChunkIDTimeout): + s.logger.Debug("fetchChunks queue empty, timed out", "timeout", dequeueChunkIDTimeout) continue } } ID, err := queue.Dequeue() if errors.Is(err, errQueueEmpty) { + s.logger.Debug("fetchChunks queue empty, waiting for chunk", "timeout", dequeueChunkIDTimeout, "err", err) continue } s.logger.Info("Fetching snapshot chunk", "height", snapshot.Height, "version", snapshot.Version, - "chunk", ID) + "chunkID", ID) ticker.Reset(s.retryTimeout) if err := s.requestChunk(ctx, snapshot, ID); err != nil { + s.logger.Error("failed to request snapshot chunk", "err", err, "chunkID", ID) + // retry the chunk + s.chunkQueue.Enqueue(ID) return } select { case <-queue.WaitFor(ID): // do nothing case <-ticker.C: + s.logger.Debug("chunk not received on time, retrying", + "chunkID", ID, "timeout", s.retryTimeout) s.chunkQueue.Enqueue(ID) case <-ctx.Done(): + s.logger.Debug("fetchChunks context done while waiting for chunk") return } } @@ -568,8 +582,10 @@ func (s *syncer) requestChunk(ctx context.Context, snapshot *snapshot, chunkID t ChunkId: chunkID, }, } + sCtx, cancel := context.WithTimeout(ctx, chunkRequestSendTimeout) + defer cancel() - return s.chunkCh.Send(ctx, msg) + return s.chunkCh.Send(sCtx, msg) } // verifyApp verifies the sync, checking the app hash, last block height and app version @@ -595,7 +611,7 @@ func (s *syncer) 
verifyApp(ctx context.Context, snapshot *snapshot, appVersion u return errVerifyFailed } - if uint64(resp.LastBlockHeight) != snapshot.Height { + if tmmath.MustConvertUint64(resp.LastBlockHeight) != snapshot.Height { s.logger.Error( "ABCI app reported unexpected last block height", "expected", snapshot.Height, diff --git a/internal/statesync/syncer_test.go b/internal/statesync/syncer_test.go index 00da83c9ed..2829aa7fe7 100644 --- a/internal/statesync/syncer_test.go +++ b/internal/statesync/syncer_test.go @@ -83,7 +83,7 @@ func (suite *SyncerTestSuite) TestSyncAny() { }, Software: version.TMCoreSemVer, }, - + InitialHeight: 1, LastBlockHeight: 1, LastBlockID: types.BlockID{Hash: []byte("blockhash")}, LastBlockTime: time.Now(), @@ -100,7 +100,10 @@ func (suite *SyncerTestSuite) TestSyncAny() { ConsensusParams: *types.DefaultConsensusParams(), LastHeightConsensusParamsChanged: 1, } - commit := &types.Commit{BlockID: types.BlockID{Hash: []byte("blockhash")}} + commit := &types.Commit{ + Height: 1, + BlockID: types.BlockID{Hash: []byte("blockhash")}, + } s := &snapshot{Height: 1, Version: 1, Hash: []byte{0}} chunks := []*chunk{ @@ -254,6 +257,7 @@ func (suite *SyncerTestSuite) TestSyncAny() { Once(). Return(asc.resp, nil) } + suite.conn. On("Info", mock.Anything, &proxy.RequestInfo). Once(). @@ -263,7 +267,7 @@ func (suite *SyncerTestSuite) TestSyncAny() { LastBlockAppHash: []byte("app_hash"), }, nil) - newState, lastCommit, err := suite.syncer.SyncAny(ctx, 0, func() error { return nil }) + newState, lastCommit, err := suite.syncer.SyncAny(ctx, 0, 0, func() error { return nil }) suite.Require().NoError(err) suite.Require().Equal([]int{0: 2, 1: 1, 2: 1, 3: 1}, chunkRequests) @@ -280,7 +284,7 @@ func (suite *SyncerTestSuite) TestSyncAnyNoSnapshots() { ctx, cancel := context.WithCancel(suite.ctx) defer cancel() - _, _, err := suite.syncer.SyncAny(ctx, 0, func() error { return nil }) + _, _, err := suite.syncer.SyncAny(ctx, 0, 0, func() error { return nil }) suite.Require().Equal(errNoSnapshots, err) } @@ -306,7 +310,7 @@ func (suite *SyncerTestSuite) TestSyncAnyAbort() { Once(). Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil) - _, _, err = suite.syncer.SyncAny(ctx, 0, func() error { return nil }) + _, _, err = suite.syncer.SyncAny(ctx, 0, 0, func() error { return nil }) suite.Require().Equal(errAbort, err) } @@ -356,7 +360,7 @@ func (suite *SyncerTestSuite) TestSyncAnyReject() { Once(). Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) - _, _, err = suite.syncer.SyncAny(ctx, 0, func() error { return nil }) + _, _, err = suite.syncer.SyncAny(ctx, 0, 0, func() error { return nil }) suite.Require().Equal(errNoSnapshots, err) } @@ -399,7 +403,7 @@ func (suite *SyncerTestSuite) TestSyncAnyRejectFormat() { Once(). Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil) - _, _, err = suite.syncer.SyncAny(ctx, 0, func() error { return nil }) + _, _, err = suite.syncer.SyncAny(ctx, 0, 0, func() error { return nil }) suite.Require().Equal(errAbort, err) } @@ -453,7 +457,7 @@ func (suite *SyncerTestSuite) TestSyncAnyRejectSender() { Once(). Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) - _, _, err := suite.syncer.SyncAny(ctx, 0, func() error { return nil }) + _, _, err := suite.syncer.SyncAny(ctx, 0, 0, func() error { return nil }) suite.Require().Equal(errNoSnapshots, err) } @@ -481,7 +485,7 @@ func (suite *SyncerTestSuite) TestSyncAnyAbciError() { Once(). 
diff --git a/libs/math/safemath.go b/libs/math/safemath.go
index 9afb409b21..3d6a5ca433 100644
--- a/libs/math/safemath.go
+++ b/libs/math/safemath.go
@@ -6,11 +6,7 @@ import (
 	"math"
 )

-var ErrOverflowInt64 = errors.New("int64 overflow")
-var ErrOverflowInt32 = errors.New("int32 overflow")
-var ErrOverflowUint32 = errors.New("uint32 overflow")
-var ErrOverflowUint8 = errors.New("uint8 overflow")
-var ErrOverflowInt8 = errors.New("int8 overflow")
+var ErrOverflow = errors.New("integer overflow")

 // SafeAddClipInt64 adds two int64 integers and clips the result to the int64 range.
 func SafeAddClipInt64(a, b int64) int64 {
@@ -27,9 +23,9 @@ func SafeAddClipInt64(a, b int64) int64 {
 // SafeAddInt64 adds two int64 integers.
 func SafeAddInt64(a, b int64) (int64, error) {
 	if b > 0 && (a > math.MaxInt64-b) {
-		return 0, ErrOverflowInt64
+		return 0, ErrOverflow
 	} else if b < 0 && (a < math.MinInt64-b) {
-		return 0, ErrOverflowInt64
+		return 0, ErrOverflow
 	}
 	return a + b, nil
 }
@@ -37,9 +33,9 @@ func SafeAddInt64(a, b int64) (int64, error) {
 // SafeAddInt32 adds two int32 integers.
 func SafeAddInt32(a, b int32) (int32, error) {
 	if b > 0 && (a > math.MaxInt32-b) {
-		return 0, ErrOverflowInt32
+		return 0, ErrOverflow
 	} else if b < 0 && (a < math.MinInt32-b) {
-		return 0, ErrOverflowInt32
+		return 0, ErrOverflow
 	}
 	return a + b, nil
 }
@@ -67,9 +63,9 @@ func SafeSubClipInt64(a, b int64) int64 {
 // SafeSubInt32 subtracts two int32 integers.
 func SafeSubInt32(a, b int32) (int32, error) {
 	if b > 0 && (a < math.MinInt32+b) {
-		return 0, ErrOverflowInt32
+		return 0, ErrOverflow
 	} else if b < 0 && (a > math.MaxInt32+b) {
-		return 0, ErrOverflowInt32
+		return 0, ErrOverflow
 	}
 	return a - b, nil
 }
@@ -77,9 +73,9 @@ func SafeSubInt32(a, b int32) (int32, error) {
 // SafeConvertInt32 takes a int and checks if it overflows.
 func SafeConvertInt32[T Integer](a T) (int32, error) {
 	if int64(a) > math.MaxInt32 {
-		return 0, ErrOverflowInt32
+		return 0, ErrOverflow
 	} else if int64(a) < math.MinInt32 {
-		return 0, ErrOverflowInt32
+		return 0, ErrOverflow
 	}
 	return int32(a), nil
 }
@@ -87,17 +83,121 @@ func SafeConvertInt32[T Integer](a T) (int32, error) {
 // SafeConvertInt32 takes a int and checks if it overflows.
 func SafeConvertUint32[T Integer](a T) (uint32, error) {
 	if uint64(a) > math.MaxUint32 {
-		return 0, ErrOverflowUint32
+		return 0, ErrOverflow
 	} else if a < 0 {
-		return 0, ErrOverflowUint32
+		return 0, ErrOverflow
 	}
 	return uint32(a), nil
 }

+// SafeConvertUint64 takes an integer and checks if it overflows.
+func SafeConvertUint64[T Integer](a T) (uint64, error) {
+	return SafeConvert[T, uint64](a)
+}
+
+// SafeConvertInt64 takes an integer and checks if it overflows.
+func SafeConvertInt64[T Integer](a T) (int64, error) {
+	return SafeConvert[T, int64](a)
+}
+
+// SafeConvertInt16 takes an integer and checks if it overflows.
+func SafeConvertInt16[T Integer](a T) (int16, error) {
+	return SafeConvert[T, int16](a)
+}
+
+// SafeConvertUint16 takes an integer and checks if it overflows.
+func SafeConvertUint16[T Integer](a T) (uint16, error) {
+	return SafeConvert[T, uint16](a)
+}
+
 type Integer interface {
 	~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
 }

+// SafeConvert converts a value of type F to a value of type T.
+// It returns an error if the conversion would cause an overflow.
+func SafeConvert[F Integer, T Integer](from F) (T, error) {
+	// check if int and uint are smaller than int64 and uint64; we use constants here for performance reasons
+	const uintIsSmall = math.MaxUint < math.MaxUint64
+	const intIsSmall = math.MaxInt < math.MaxInt64 && math.MinInt > math.MinInt64
+
+	// special case for int64 and uint64 inputs; all other types are safe to convert to int64
+	switch any(from).(type) {
+	case int64:
+		// conversion from int64 to uint64 - we need to check for negative values
+		if _, ok := any(T(0)).(uint64); ok && from < 0 {
+			return 0, ErrOverflow
+		}
+		// return T(from), nil
+	case uint64:
+		// conversion from uint64 to int64 - we need to check for overflow
+		if _, ok := any(T(0)).(int64); ok && uint64(from) > math.MaxInt64 {
+			return 0, ErrOverflow
+		}
+		// return T(from), nil
+	case int:
+		if !intIsSmall {
+			// when int isn't smaller than int64, we just fall back to int64
+			return SafeConvert[int64, T](int64(from))
+		}
+		// no return here - it's safe to use normal logic
+	case uint:
+		if !uintIsSmall {
+			// when uint isn't smaller than uint64, we just fall back to uint64
+			return SafeConvert[uint64, T](uint64(from))
+		}
+		// no return here - it's safe to use normal logic
+	}
+	if from >= 0 && uint64(from) > Max[T]() {
+		return 0, ErrOverflow
+	}
+	if from <= 0 && int64(from) < Min[T]() {
+		return 0, ErrOverflow
+	}
+	return T(from), nil
+}
+
+// MustConvert converts a value of type FROM to a value of type TO.
+// It panics if the conversion would cause an overflow.
+//
+// See SafeConvert for a non-panicking version.
+func MustConvert[FROM Integer, TO Integer](a FROM) TO {
+	i, err := SafeConvert[FROM, TO](a)
+	if err != nil {
+		var zero TO
+		panic(fmt.Errorf("cannot convert %d to %T: %w", a, zero, err))
+	}
+	return i
+}
+
+func MustConvertUint64[T Integer](a T) uint64 {
+	return MustConvert[T, uint64](a)
+}
+
+func MustConvertInt64[T Integer](a T) int64 {
+	return MustConvert[T, int64](a)
+}
+
+func MustConvertUint16[T Integer](a T) uint16 {
+	return MustConvert[T, uint16](a)
+}
+
+func MustConvertInt16[T Integer](a T) int16 {
+	return MustConvert[T, int16](a)
+}
+
+func MustConvertUint8[T Integer](a T) uint8 {
+	return MustConvert[T, uint8](a)
+}
+
+func MustConvertUint[T Integer](a T) uint {
+	return MustConvert[T, uint](a)
+}
+
+func MustConvertInt[T Integer](a T) int {
+	return MustConvert[T, int](a)
+}
+
 // MustConvertInt32 takes an Integer and converts it to int32.
 // Panics if the conversion overflows.
 func MustConvertInt32[T Integer](a T) int32 {
@@ -121,9 +221,9 @@ func MustConvertUint32[T Integer](a T) uint32 {
 // SafeConvertUint8 takes an int64 and checks if it overflows.
 func SafeConvertUint8(a int64) (uint8, error) {
 	if a > math.MaxUint8 {
-		return 0, ErrOverflowUint8
+		return 0, ErrOverflow
 	} else if a < 0 {
-		return 0, ErrOverflowUint8
+		return 0, ErrOverflow
 	}
 	return uint8(a), nil
 }
@@ -131,9 +231,9 @@ func SafeConvertUint8(a int64) (uint8, error) {
 // SafeConvertInt8 takes an int64 and checks if it overflows.
 func SafeConvertInt8(a int64) (int8, error) {
 	if a > math.MaxInt8 {
-		return 0, ErrOverflowInt8
+		return 0, ErrOverflow
 	} else if a < math.MinInt8 {
-		return 0, ErrOverflowInt8
+		return 0, ErrOverflow
 	}
 	return int8(a), nil
 }
@@ -159,3 +259,54 @@ func SafeMulInt64(a, b int64) (int64, bool) {

 	return a * b, false
 }
+
+// Max returns the maximum value for a type T.
+//
+// The function panics if the type is not supported.
+func Max[T Integer]() uint64 {
+	var max T
+	switch any(max).(type) {
+	case int:
+		return uint64(math.MaxInt)
+	case int8:
+		return uint64(math.MaxInt8)
+	case int16:
+		return uint64(math.MaxInt16)
+	case int32:
+		return uint64(math.MaxInt32)
+	case int64:
+		return uint64(math.MaxInt64)
+	case uint:
+		return uint64(math.MaxUint)
+	case uint8:
+		return uint64(math.MaxUint8)
+	case uint16:
+		return uint64(math.MaxUint16)
+	case uint32:
+		return uint64(math.MaxUint32)
+	case uint64:
+		return uint64(math.MaxUint64)
+	default:
+		panic(fmt.Sprintf("unsupported type %T", max))
+	}
+}
+
+// Min returns the minimum value for a type T.
+func Min[T Integer]() int64 {
+	switch any(T(0)).(type) {
+	case int:
+		return int64(math.MinInt)
+	case int8:
+		return int64(math.MinInt8)
+	case int16:
+		return int64(math.MinInt16)
+	case int32:
+		return int64(math.MinInt32)
+	case int64:
+		return math.MinInt64
+	case uint, uint8, uint16, uint32, uint64:
+		return 0
+	default:
+		panic("unsupported type")
+	}
+}
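A usage sketch for the conversion helpers added above; the standalone main and sample values are illustrative, and the `tmmath` alias matches the one used elsewhere in this diff:

package main

import (
	"fmt"
	"math"

	tmmath "github.com/dashpay/tenderdash/libs/math"
)

func main() {
	// SafeConvert reports overflow instead of silently truncating.
	if _, err := tmmath.SafeConvert[int64, uint64](-1); err != nil {
		fmt.Println("negative int64 does not fit in uint64:", err)
	}

	// The typed wrappers cover the common cases.
	height, err := tmmath.SafeConvertUint64(int64(12345))
	fmt.Println(height, err) // 12345 <nil>

	// MustConvert panics on overflow; use it where the range was already validated.
	fmt.Println(tmmath.MustConvertInt64(uint64(math.MaxInt64))) // fits exactly

	// This would panic: math.MaxUint64 does not fit in int64.
	// _ = tmmath.MustConvertInt64(uint64(math.MaxUint64))
}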
diff --git a/libs/math/safemath_test.go b/libs/math/safemath_test.go
index 92a8f32110..301c756f48 100644
--- a/libs/math/safemath_test.go
+++ b/libs/math/safemath_test.go
@@ -1,6 +1,7 @@
 package math

 import (
+	"fmt"
 	"math"
 	"testing"
 	"testing/quick"
@@ -84,3 +85,129 @@ func TestSafeMul(t *testing.T) {
 		assert.Equal(t, tc.overflow, overflow, "#%d", i)
 	}
 }
+
+func TestSafeConvert(t *testing.T) {
+	testCases := []struct {
+		from interface{}
+		want interface{}
+		err  bool
+	}{
+		{int(0), int64(0), false},
+		{int(math.MaxInt), int64(math.MaxInt), false},
+		{int(math.MinInt), int64(math.MinInt), false},
+		{uint(0), uint64(0), false},
+		{uint(math.MaxUint), uint64(math.MaxUint), false},
+		{int64(0), uint64(0), false},
+		{int64(math.MaxInt64), uint64(math.MaxInt64), false},
+		{int64(math.MinInt64), uint64(0), true},
+		{uint64(math.MaxUint64), int64(0), true},
+		{uint64(math.MaxInt64), int64(math.MaxInt64), false},
+		{int32(-1), uint32(0), true},
+		{int32(0), uint32(0), false},
+		{int32(math.MaxInt32), uint32(math.MaxInt32), false},
+		{int32(math.MaxInt32), int16(0), true},
+		{int32(math.MinInt32), int16(0), true},
+		{int32(0), int16(0), false},
+		{uint32(math.MaxUint32), int32(0), true},
+		{uint32(math.MaxInt32), int32(math.MaxInt32), false},
+		{uint32(0), int32(0), false},
+		{int16(0), uint32(0), false},
+		{int16(-1), uint32(0), true},
+		{int16(math.MaxInt16), uint32(math.MaxInt16), false},
+		{int64(math.MinInt16), int16(math.MinInt16), false},
+		{int64(math.MinInt16 - 1), int16(0), true},
+		{int32(math.MinInt16), int16(math.MinInt16), false},
+		{int32(math.MinInt16 - 1), int16(0), true},
+		{int32(math.MinInt16 + 1), int16(math.MinInt16 + 1), false},
+		{int32(math.MaxInt16), int16(math.MaxInt16), false},
+		{int32(math.MaxInt16 + 1), int16(0), true},
+		{int32(math.MaxInt16 - 1), int16(math.MaxInt16 - 1), false},
+	}
+
+	for i, tc := range testCases {
+		testName := fmt.Sprintf("%d:%T(%d)-%T(%d)", i, tc.from, tc.from, tc.want, tc.want)
+		t.Run(testName, func(t *testing.T) {
+			var result interface{}
+			var err error
+
+			switch from := tc.from.(type) {
+			case int:
+				switch tc.want.(type) {
+				case int64:
+					result, err = SafeConvert[int, int64](from)
+				default:
+					t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
+				}
+			case uint:
+				switch tc.want.(type) {
+				case uint64:
+					result, err = SafeConvert[uint, uint64](from)
+				default:
+					t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
+				}
+			case int64:
+				switch tc.want.(type) {
+				case uint64:
+					result, err = SafeConvert[int64, uint64](from)
+				case int64:
+					result, err = SafeConvert[int64, int64](from)
+				case uint16:
+					result, err = SafeConvert[int64, uint16](from)
+				case int16:
+					result, err = SafeConvert[int64, int16](from)
+				default:
+					t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
+				}
+			case uint64:
+				switch tc.want.(type) {
+				case int64:
+					result, err = SafeConvert[uint64, int64](from)
+				default:
+					t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
+				}
+			case int32:
+				switch tc.want.(type) {
+				case int16:
+					result, err = SafeConvert[int32, int16](from)
+				case uint32:
+					result, err = SafeConvert[int32, uint32](from)
+				default:
+					t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
+				}
+			case uint32:
+				switch tc.want.(type) {
+				case int16:
+					result, err = SafeConvert[uint32, int16](from)
+				case int32:
+					result, err = SafeConvert[uint32, int32](from)
+				default:
+					t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
+				}
+			case int16:
+				switch tc.want.(type) {
+				case int32:
+					result, err = SafeConvert[int16, int32](from)
+				case uint32:
+					result, err = SafeConvert[int16, uint32](from)
+				default:
+					t.Fatalf("test case %d: unsupported target type %T", i, tc.want)
+				}
+			default:
+				t.Fatalf("test case %d: unsupported source type %T", i, tc.from)
+			}
+
+			if (err != nil) != tc.err {
+				t.Errorf("test case %d: expected error %v, got %v", i, tc.err, err)
+			}
+			if err == nil && result != tc.want {
+				t.Errorf("test case %d: expected result %v, got %v", i, tc.want, result)
+			}
+		})
+	}
+}
+
+func TestMustConvertPanics(t *testing.T) {
+	assert.NotPanics(t, func() { MustConvert[int32, int32](0) })
+	assert.Panics(t, func() { MustConvert[int32, int16](math.MaxInt16 + 1) })
+	assert.NotPanics(t, func() { MustConvert[int32, int16](math.MaxInt16) })
+}
diff --git a/test/e2e/networks/rotate.toml b/test/e2e/networks/rotate.toml
index 2e725ec190..30f6943097 100644
--- a/test/e2e/networks/rotate.toml
+++ b/test/e2e/networks/rotate.toml
@@ -152,7 +152,7 @@ start_at = 1005 # Becomes part of the validator set at 1030 to ensure ther
 seeds = ["seed01"]
 snapshot_interval = 5
 block_sync = "v0"
-#state_sync = "p2p"
+state_sync = "p2p"
 #persistent_peers = ["validator01", "validator02", "validator03", "validator04", "validator05", "validator07", "validator08"]
 perturb = ["pause", "disconnect", "restart"]
 privval_protocol = "dashcore"
@@ -192,7 +192,7 @@ privval_protocol = "dashcore"
 start_at = 1030
 mode = "full"
 block_sync = "v0"
-#state_sync = "rpc"
+state_sync = "p2p"
 persistent_peers = [
 	"validator01",
 	"validator02",
diff --git a/test/e2e/pkg/mockcoreserver/core_server.go b/test/e2e/pkg/mockcoreserver/core_server.go
index 4b6025e125..7b51d1a8ad 100644
--- a/test/e2e/pkg/mockcoreserver/core_server.go
+++ b/test/e2e/pkg/mockcoreserver/core_server.go
@@ -9,6 +9,7 @@ import (
 	"github.com/dashpay/dashd-go/btcjson"

 	"github.com/dashpay/tenderdash/crypto"
+	"github.com/dashpay/tenderdash/libs/math"
 	"github.com/dashpay/tenderdash/privval"
 	"github.com/dashpay/tenderdash/types"
 )
@@ -68,8 +69,9 @@ func (c *MockCoreServer) QuorumInfo(ctx context.Context, cmd btcjson.QuorumCmd)
 	if err != nil {
 		panic(err)
 	}
+
 	return btcjson.QuorumInfoResult{
-		Height:     uint32(height),
+		Height:     math.MustConvertUint32(height),
 		Type:       strconv.Itoa(int(c.LLMQType)),
 		QuorumHash: quorumHash.String(),
 		Members:    members,
@@ -146,9 +148,7 @@ func (c *MockCoreServer) QuorumVerify(ctx context.Context, cmd btcjson.QuorumCmd
 	signatureVerified := thresholdPublicKey.VerifySignatureDigest(signID, signature)

-	res := btcjson.QuorumVerifyResult{
-		Result: signatureVerified,
-	}
+	res := btcjson.QuorumVerifyResult{Result: signatureVerified}
 	return res
 }
diff --git a/test/e2e/pkg/mockcoreserver/methods.go b/test/e2e/pkg/mockcoreserver/methods.go
index 5894dfa867..b38f17982b 100644
--- a/test/e2e/pkg/mockcoreserver/methods.go
+++ b/test/e2e/pkg/mockcoreserver/methods.go
@@ -64,8 +64,8 @@ func WithQuorumVerifyMethod(cs CoreServer, times int) MethodFunc {
 			&cmd.LLMQType,
 			&cmd.RequestID,
 			&cmd.MessageHash,
-			&cmd.QuorumHash,
 			&cmd.Signature,
+			&cmd.QuorumHash,
 		)
 		if err != nil {
 			return nil, err
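The pointer-order swap in WithQuorumVerifyMethod above suggests that the command's fields are filled positionally from the JSON-RPC params array, in which case the old order would have written the signature into `QuorumHash` and vice versa. A toy illustration of positional decoding; the helper and the param values are hypothetical, not the dashd-go API:

package main

import (
	"encoding/json"
	"fmt"
)

// decodePositional assigns params[i] to targets[i], as positional JSON-RPC
// decoding does; swapping two pointers silently swaps the decoded values.
func decodePositional(params []json.RawMessage, targets ...interface{}) error {
	for i, target := range targets {
		if err := json.Unmarshal(params[i], target); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	params := []json.RawMessage{
		[]byte(`"deadbeef"`), // signature comes first on the wire
		[]byte(`"0f0f0f0f"`), // quorum hash comes second
	}
	var signature, quorumHash string
	// pointers listed in wire order: &signature before &quorumHash
	if err := decodePositional(params, &signature, &quorumHash); err != nil {
		panic(err)
	}
	fmt.Println(signature, quorumHash) // deadbeef 0f0f0f0f
}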
diff --git a/test/e2e/runner/setup.go b/test/e2e/runner/setup.go
index ad14e43064..b03046fee7 100644
--- a/test/e2e/runner/setup.go
+++ b/test/e2e/runner/setup.go
@@ -9,7 +9,6 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
-	"regexp"
 	"sort"
 	"strconv"
 	"strings"
@@ -458,23 +457,6 @@ func MakeAppConfig(node *e2e.Node) ([]byte, error) {
 	return buf.Bytes(), nil
 }

-// UpdateConfigStateSync updates the state sync config for a node.
-func UpdateConfigStateSync(node *e2e.Node, height int64, hash []byte) error {
-	cfgPath := filepath.Join(node.Testnet.Dir, node.Name, "config", "config.toml")
-
-	// FIXME Apparently there's no function to simply load a config file without
-	// involving the entire Viper apparatus, so we'll just resort to regexps.
-	bz, err := os.ReadFile(cfgPath)
-	if err != nil {
-		return err
-	}
-	bz = regexp.MustCompile(`(?m)^trust-height =.*`).ReplaceAll(bz, []byte(fmt.Sprintf(`trust-height = %v`, height-1)))
-	bz = regexp.MustCompile(`(?m)^trust-hash =.*`).ReplaceAll(bz, []byte(fmt.Sprintf(`trust-hash = "%X"`, hash)))
-	//nolint: gosec
-	// G306: Expect WriteFile permissions to be 0600 or less
-	return os.WriteFile(cfgPath, bz, 0644)
-}
-
 func newDefaultFilePV(node *e2e.Node, nodeDir string) (*privval.FilePV, error) {
 	return privval.NewFilePVWithOptions(
 		privval.WithPrivateKeysMap(node.PrivvalKeys),
diff --git a/test/e2e/runner/start.go b/test/e2e/runner/start.go
index 3b45597fb9..64c5b0ce8f 100644
--- a/test/e2e/runner/start.go
+++ b/test/e2e/runner/start.go
@@ -70,11 +70,6 @@ func Start(ctx context.Context, logger log.Logger, testnet *e2e.Testnet, ti infr
 		"nodes", len(testnet.Nodes)-len(nodeQueue),
 		"pending", len(nodeQueue))

-	block, blockID, err := waitForHeight(ctx, testnet, networkHeight)
-	if err != nil {
-		return err
-	}
-
 	for _, node := range nodeQueue {
 		if node.StartAt > networkHeight {
 			// if we're starting a node that's ahead of
@@ -93,16 +88,7 @@ func Start(ctx context.Context, logger log.Logger, testnet *e2e.Testnet, ti infr

 			networkHeight = node.StartAt

-			block, blockID, err = waitForHeight(ctx, testnet, networkHeight)
-			if err != nil {
-				return err
-			}
-		}
-
-		// Update any state sync nodes with a trusted height and hash
-		if node.StateSync != e2e.StateSyncDisabled || node.Mode == e2e.ModeLight {
-			err = UpdateConfigStateSync(node, block.Height, blockID.Hash.Bytes())
-			if err != nil {
+			if _, _, err := waitForHeight(ctx, testnet, networkHeight); err != nil {
 				return err
 			}
 		}
diff --git a/types/quorum_sign_data.go b/types/quorum_sign_data.go
index 1bd51c6224..2279b47cc9 100644
--- a/types/quorum_sign_data.go
+++ b/types/quorum_sign_data.go
@@ -11,6 +11,7 @@ import (

 	"github.com/dashpay/tenderdash/crypto"
 	tmbytes "github.com/dashpay/tenderdash/libs/bytes"
+	tmmath "github.com/dashpay/tenderdash/libs/math"
 	"github.com/dashpay/tenderdash/proto/tendermint/types"
 )
@@ -161,7 +162,7 @@ func (i *SignItem) Validate() error {
 	if len(i.MsgHash) != crypto.DefaultHashSize {
 		return fmt.Errorf("invalid hash size %d: %X", len(i.MsgHash), i.MsgHash)
 	}
-	if len(i.QuorumHash) != crypto.DefaultHashSize {
+	if len(i.QuorumHash) != crypto.QuorumHashSize {
 		return fmt.Errorf("invalid quorum hash size %d: %X", len(i.QuorumHash), i.QuorumHash)
 	}
 	// Msg is optional
@@ -179,7 +180,7 @@ func (i SignItem) MarshalZerologObject(e *zerolog.Event) {
 	e.Hex("signID", i.SignHash)
 	e.Hex("msgHash", i.MsgHash)
 	e.Hex("quorumHash", i.QuorumHash)
-	e.Uint8("llmqType", uint8(i.LlmqType))
+	e.Uint8("llmqType", tmmath.MustConvertUint8(i.LlmqType))
 }
@@ -247,7 +248,7 @@ func (i *SignItem) UpdateSignHash(reverse bool) {
 	// fmt.Printf("RequestID: %x + ", blsRequestID)
 	// fmt.Printf("MsgHash: %x\n", blsMessageHash)

-	blsSignHash := bls.BuildSignHash(uint8(llmqType), blsQuorumHash, blsRequestID, blsMessageHash)
+	blsSignHash := bls.BuildSignHash(tmmath.MustConvertUint8(llmqType), blsQuorumHash, blsRequestID, blsMessageHash)
 	signHash := make([]byte, 32)
 	copy(signHash, blsSignHash[:])
diff --git a/types/validator_test.go b/types/validator_test.go
index 72d9051100..5f052f77c3 100644
--- a/types/validator_test.go
+++ b/types/validator_test.go
@@ -2,12 +2,15 @@ package types

 import (
 	"context"
+	"encoding/base64"
+	"encoding/hex"
 	"testing"

 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"

 	"github.com/dashpay/tenderdash/crypto"
+	"github.com/dashpay/tenderdash/crypto/bls12381"
 )

 func TestValidatorProtoBuf(t *testing.T) {
@@ -111,3 +114,22 @@ func TestValidatorValidateBasic(t *testing.T) {
 		})
 	}
 }
+
+// TestValidatorSetHashVectors checks that the provided validator threshold pubkey and quorum hash produce the expected hash
+func TestValidatorSetHashVectors(t *testing.T) {
+	thresholdPublicKey, err := base64.RawStdEncoding.DecodeString("gw5F5F5kFNnWFUc8woFOaxccUI+cd+ixaSS3RZT2HJlWpvoWM16YRn6sjYvbdtGH")
+	require.NoError(t, err)
+
+	quorumHash, err := hex.DecodeString("703ee5bfc78765cc9e151d8dd84e30e196ababa83ac6cbdee31a88a46bba81b9")
+	require.NoError(t, err)
+
+	expected := "81742F95E99EAE96ABC727FE792CECB4996205DE6BFC88AFEE1F60B96BC648B2"
+
+	valset := ValidatorSet{
+		ThresholdPublicKey: bls12381.PubKey(thresholdPublicKey),
+		QuorumHash:         quorumHash,
+	}
+
+	hash := valset.Hash()
+	assert.Equal(t, expected, hash.String())
+}